|      1 #!/usr/bin/env python3 |         | 
|      2  |         | 
|      3 from pgtoolkit import toolbase |         | 
|      4 from pgtoolkit.pgmanager import IntegrityError |         | 
|      5  |         | 
|      6  |         | 
class BatchCopyTool(toolbase.SrcDstTablesTool):
    """CLI tool copying rows of ``--table-name`` from the 'src' to the 'dst' database.

    Source rows are selected with ``SELECT * FROM <table> WHERE <filter>``.
    They are inserted one by one into the same-named destination table, each
    inside its own savepoint, so that ``--dst-exists`` can decide what happens
    when an insert hits an IntegrityError.
    """

    def __init__(self):
        toolbase.SrcDstTablesTool.__init__(self, name='batchcopy', desc='Copy data from one table to another.')

        self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
        self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
        self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
        self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')

        self.init()

    @staticmethod
    def _read_ids(path):
        """Return the non-blank, whitespace-stripped lines of *path* joined by commas.

        Blank lines (e.g. a trailing newline at end of file) are skipped so
        they do not produce empty items like ``1,2,,`` in the SQL list.
        """
        with open(path, 'r') as f:
            stripped = (ln.strip() for ln in f)
            return ','.join(item for item in stripped if item)

    @staticmethod
    def _quote_ident(name):
        """Quote *name* as a PostgreSQL identifier, doubling embedded double quotes."""
        return '"{}"'.format(name.replace('"', '""'))

    def main(self):
        """Copy the selected rows and log how many rows were committed.

        Raises:
            NotImplementedError: an insert hit an IntegrityError while
                ``--dst-exists=update`` is selected (not implemented yet).
        """
        # Value substituted for {ids} in --src-filter. The placeholder text
        # makes a filter that uses {ids} without --file-with-ids fail loudly
        # in SQL rather than silently matching nothing.
        ids = '<no IDs read>'
        if self.args.file_with_ids:
            ids = self._read_ids(self.args.file_with_ids)

        # Without --src-filter (or with an empty one), copy the whole table.
        condition = (self.args.src_filter or '').format(ids=ids) or 'true'

        # NOTE: table name and filter are interpolated verbatim as SQL; they
        # come from the command line and are trusted by design (the filter is
        # raw SQL, the table name may be schema-qualified).
        with self.pgm.cursor('src') as src_curs:
            src_curs.execute('SELECT * FROM {} WHERE {}'.format(self.args.table_name, condition))
            # TODO: ORDER BY id OFFSET 0 LIMIT 100
            data = src_curs.fetchall_dict()
            src_curs.connection.commit()

        copied = 0
        with self.pgm.cursor('dst') as dst_curs:
            for row in data:
                # Column names are quoted so reserved words and mixed-case
                # names survive; presumably the keys of fetchall_dict() rows
                # are the real column names from the source table.
                columns = ', '.join(self._quote_ident(col) for col in row)
                placeholders = ', '.join(['%s'] * len(row))
                query = 'INSERT INTO {} ({}) VALUES ({})'.format(self.args.table_name, columns, placeholders)
                # Per-row savepoint: a failed insert can be undone alone
                # without aborting the whole destination transaction.
                dst_curs.execute('SAVEPOINT the_query;')
                try:
                    dst_curs.execute(query, list(row.values()))
                except IntegrityError:
                    if self.args.dst_exists == 'rollback':
                        # Abort the whole batch: every insert made so far is
                        # discarded, so none of them count as copied.
                        dst_curs.connection.rollback()
                        copied = 0
                        break
                    elif self.args.dst_exists == 'ignore':
                        # Skip just this row and continue with the next one.
                        dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
                    elif self.args.dst_exists == 'update':
                        raise NotImplementedError()
                else:
                    dst_curs.execute('RELEASE SAVEPOINT the_query;')
                    copied += 1
            dst_curs.connection.commit()

        self.log.info('Copied %s rows.', copied)
|     55  |         | 
|     56  |         | 
# Run the tool only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    tool = BatchCopyTool()
    tool.main()