1 #!/usr/bin/env python3 |
|
2 |
|
3 from pgtoolkit import toolbase |
|
4 from pgtoolkit.pgmanager import IntegrityError |
|
5 |
|
6 |
|
class BatchCopyTool(toolbase.SrcDstTablesTool):
    """Copy rows of one table from the 'src' cursor's database to the 'dst' one.

    Rows are selected from ``--table-name`` via ``self.pgm.cursor('src')``
    (optionally restricted by ``--src-filter``). They are then inserted one by
    one into the same-named table via ``self.pgm.cursor('dst')``. Each insert
    runs inside a savepoint, so a row that raises IntegrityError can be skipped
    (``--dst-exists ignore``) without aborting the destination transaction.
    """

    def __init__(self):
        toolbase.SrcDstTablesTool.__init__(self, name='batchcopy', desc='Copy data from one table to another.')

        self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
        self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
        self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
        self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')

        self.init()

    def main(self):
        """Copy the selected rows and log how many were committed.

        Raises:
            NotImplementedError: an insert hit IntegrityError while
                ``--dst-exists update`` was requested.
        """
        ids = self._read_ids()
        data = self._fetch_source_rows(ids)
        copied = self._insert_rows(data)
        self.log.info('Copied %s rows.', copied)

    def _read_ids(self):
        """Return the IDs from --file-with-ids joined by commas (the {ids} value)."""
        if not self.args.file_with_ids:
            # Placeholder substituted for {ids} when no ID file was given.
            return '<no IDs read>'
        with open(self.args.file_with_ids, 'r') as f:
            # Skip blank lines (e.g. trailing empty lines); otherwise the
            # joined list would contain empty items like "1,2,,3".
            return ','.join(line.strip() for line in f if line.strip())

    def _fetch_source_rows(self, ids):
        """Select the rows to copy from the source table (rows from fetchall_dict)."""
        # With no --src-filter (None or empty), copy the whole table.
        # Calling .format() on None would raise AttributeError.
        if self.args.src_filter:
            condition = self.args.src_filter.format(ids=ids) or 'true'
        else:
            condition = 'true'
        with self.pgm.cursor('src') as src_curs:
            src_curs.execute('SELECT * FROM {} WHERE {}'.format(self.args.table_name, condition))
            #TODO: ORDER BY id OFFSET 0 LIMIT 100
            data = src_curs.fetchall_dict()
            src_curs.connection.commit()
        return data

    def _insert_rows(self, data):
        """Insert *data* into the destination table; return the number of rows committed."""
        copied = 0
        with self.pgm.cursor('dst') as dst_curs:
            for row in data:
                keys = ', '.join(row.keys())
                values_mask = ', '.join(['%s'] * len(row))
                query = 'INSERT INTO {} ({}) VALUES ({})'.format(self.args.table_name, keys, values_mask)
                try:
                    # The savepoint lets a single failed insert be undone
                    # without aborting the whole destination transaction.
                    dst_curs.execute('SAVEPOINT the_query;')
                    dst_curs.execute(query, list(row.values()))
                    dst_curs.execute('RELEASE SAVEPOINT the_query;')
                    copied += 1
                except IntegrityError:
                    if self.args.dst_exists == 'rollback':
                        dst_curs.connection.rollback()
                        self.log.error('Integrity error while inserting into %s; rolled back %s inserted rows.',
                                       self.args.table_name, copied)
                        # Every insert so far was discarded, so nothing was copied.
                        copied = 0
                        break
                    elif self.args.dst_exists == 'ignore':
                        dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
                    elif self.args.dst_exists == 'update':
                        raise NotImplementedError()
            dst_curs.connection.commit()
        return copied
|
55 |
|
56 |
|
# Run the tool only when executed as a script. Importing this module must not
# parse argv or start copying data.
if __name__ == '__main__':
    tool = BatchCopyTool()
    tool.main()
|