pgtoolkit/tools/batchcopy.py
changeset 101 2a2d0d5df03b
parent 98 024299702087
equal deleted inserted replaced
100:d6088dba8fea 101:2a2d0d5df03b
       
     1 from pgtoolkit.toolbase import SrcDstTablesTool
       
     2 from pgtoolkit.pgmanager import IntegrityError
       
     3 
       
     4 
       
     5 class BatchCopyTool(SrcDstTablesTool):
       
     6 
       
     7     """
       
     8     Copy data from one table to another, filtering by specified condition.
       
     9 
       
    10     """
       
    11 
       
    12     def __init__(self):
       
    13         SrcDstTablesTool.__init__(self, name='batchcopy', desc='')
       
    14 
       
    15     def specify_args(self):
       
    16         SrcDstTablesTool.specify_args(self)
       
    17         self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
       
    18         self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
       
    19         self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
       
    20         self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')
       
    21 
       
    22     def main(self):
       
    23         # read list of IDs from file
       
    24         ids = '<no IDs read>'
       
    25         if self.args.file_with_ids:
       
    26             with open(self.args.file_with_ids, 'r') as f:
       
    27                 ids = ','.join(ln.rstrip() for ln in f.readlines())
       
    28 
       
    29         # read source data
       
    30         with self.pgm.cursor('src') as src_curs:
       
    31             condition = self.args.src_filter.format(ids=ids) or 'true'
       
    32             src_curs.execute('SELECT * FROM {} WHERE {}'.format(self.args.table_name, condition))
       
    33             #TODO:  ORDER BY id OFFSET 0 LIMIT 100
       
    34             data = src_curs.fetchall_dict()
       
    35             src_curs.connection.commit()
       
    36 
       
    37         with self.pgm.cursor('dst') as dst_curs:
       
    38             copied = 0
       
    39             for row in data:
       
    40                 keys = ', '.join(row.keys())
       
    41                 values_mask = ', '.join(['%s'] * len(row))
       
    42                 query = 'INSERT INTO {} ({}) VALUES ({})'.format(self.args.table_name, keys, values_mask)
       
    43                 try:
       
    44                     dst_curs.execute('SAVEPOINT the_query;')
       
    45                     dst_curs.execute(query, list(row.values()))
       
    46                     dst_curs.execute('RELEASE SAVEPOINT the_query;')
       
    47                     copied += 1
       
    48                 except IntegrityError:
       
    49                     if self.args.dst_exists == 'rollback':
       
    50                         dst_curs.connection.rollback()
       
    51                         break
       
    52                     elif self.args.dst_exists == 'ignore':
       
    53                         dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
       
    54                     elif self.args.dst_exists == 'update':
       
    55                         raise NotImplementedError()
       
    56             dst_curs.connection.commit()
       
    57 
       
    58         self.log.info('Copied %s rows.', copied)
       
    59 
       
    60 
       
    61 cls = BatchCopyTool