1 from pgtoolkit.toolbase import SrcDstTablesTool |
|
2 from pgtoolkit.pgmanager import IntegrityError |
|
3 |
|
4 |
|
class BatchCopyTool(SrcDstTablesTool):

    """
    Copy data from one table to another, filtering by specified condition.

    Rows are selected from ``--table-name`` on the ``'src'`` connection with
    ``SELECT * FROM <table> WHERE <--src-filter>`` and inserted one at a time
    into the same-named table on the ``'dst'`` connection. Each insert is
    wrapped in a savepoint so that an ``IntegrityError`` can be handled
    according to ``--dst-exists``:

    * ``rollback`` (default) -- roll back the whole destination transaction
      and stop; nothing is copied.
    * ``ignore`` -- undo only the failing insert and continue with the next row.
    * ``update`` -- not implemented; raises ``NotImplementedError`` on the
      first integrity error.
    """

    def __init__(self):
        SrcDstTablesTool.__init__(self, name='batchcopy', desc='')

    def specify_args(self):
        """Register this tool's command-line arguments on top of the base tool's."""
        SrcDstTablesTool.specify_args(self)
        self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
        self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
        self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
        self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')

    def main(self):
        """Copy the selected rows from 'src' to 'dst' and log how many were kept."""
        ids = self._read_ids()
        rows = self._fetch_source_rows(ids)
        copied = self._insert_rows(rows)
        self.log.info('Copied %s rows.', copied)

    def _read_ids(self):
        """Return the IDs from ``--file-with-ids`` as a comma-separated string.

        One ID is expected per line. Trailing whitespace is stripped, and blank
        lines are skipped so that they cannot produce empty list items such as
        ``1,,2``. When no file is given, the placeholder ``'<no IDs read>'`` is
        returned instead. That placeholder is not valid SQL, so a filter that
        uses ``{ids}`` without a file fails at query time.
        """
        if not self.args.file_with_ids:
            return '<no IDs read>'
        with open(self.args.file_with_ids, 'r') as f:
            stripped = (line.rstrip() for line in f)
            return ','.join(item for item in stripped if item)

    def _fetch_source_rows(self, ids):
        """Select the rows to copy from the ``'src'`` connection.

        ``--src-filter`` is a raw SQL condition in which ``{ids}`` is replaced
        by *ids*. A missing or empty filter (or one that formats to an empty
        string) selects every row.
        """
        # A missing filter is None; fall back to '' so that .format() cannot
        # fail, then to 'true' (all rows).
        condition = (self.args.src_filter or '').format(ids=ids) or 'true'
        with self.pgm.cursor('src') as src_curs:
            # The table name and filter come from the operator's command line
            # and are interpolated verbatim (the filter is meant to be raw SQL).
            src_curs.execute('SELECT * FROM {} WHERE {}'.format(self.args.table_name, condition))
            # TODO: ORDER BY id OFFSET 0 LIMIT 100
            data = src_curs.fetchall_dict()
            src_curs.connection.commit()
        return data

    @staticmethod
    def _quote_ident(name):
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"{}"'.format(name.replace('"', '""'))

    def _insert_rows(self, data):
        """Insert each dict row of *data* into ``--table-name`` on 'dst'.

        Returns the number of rows that remain inserted after the final commit.
        This is 0 when ``--dst-exists=rollback`` hits an integrity error,
        because the whole destination transaction is then rolled back.
        """
        with self.pgm.cursor('dst') as dst_curs:
            copied = 0
            for row in data:
                # Quote column names so that reserved words and mixed-case
                # names taken from the source row are matched exactly.
                columns = ', '.join(self._quote_ident(name) for name in row.keys())
                values_mask = ', '.join(['%s'] * len(row))
                query = 'INSERT INTO {} ({}) VALUES ({})'.format(self.args.table_name, columns, values_mask)
                try:
                    # The savepoint lets one failed INSERT be undone without
                    # aborting the rest of the transaction (used by 'ignore').
                    dst_curs.execute('SAVEPOINT the_query;')
                    dst_curs.execute(query, list(row.values()))
                    dst_curs.execute('RELEASE SAVEPOINT the_query;')
                except IntegrityError:
                    if self.args.dst_exists == 'rollback':
                        dst_curs.connection.rollback()
                        # Every row inserted so far was discarded too.
                        copied = 0
                        self.log.warning('Integrity error on insert; all inserted rows rolled back.')
                        break
                    elif self.args.dst_exists == 'ignore':
                        dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
                    elif self.args.dst_exists == 'update':
                        raise NotImplementedError()
                else:
                    copied += 1
            dst_curs.connection.commit()
        return copied
|
59 |
|
60 |
|
# Module-level alias for the tool class; presumably the name the pgtoolkit
# tool loader looks up to instantiate this tool -- confirm against the loader.
cls = BatchCopyTool
|