# Refactor ToolBase to allow tool composition. Add TableSync tool (composite). Move more tools under pgtool.
from pgtoolkit.toolbase import SrcDstTablesTool
from pgtoolkit.pgmanager import IntegrityError
class BatchCopyTool(SrcDstTablesTool):
    """
    Copy data from one table to another, filtering by specified condition.

    Rows are read from the 'src' connection with
    ``SELECT * FROM <table-name> WHERE <src-filter>`` and inserted one by one
    into the same-named table on the 'dst' connection. Each INSERT runs inside
    its own savepoint so that an IntegrityError on a single row can be handled
    according to ``--dst-exists``:

    * ``rollback`` -- discard the whole destination transaction and stop,
    * ``ignore``   -- skip the conflicting row and continue,
    * ``update``   -- not implemented yet.
    """

    def __init__(self):
        SrcDstTablesTool.__init__(self, name='batchcopy', desc='')

    def specify_args(self):
        """Register the batchcopy-specific command line arguments."""
        SrcDstTablesTool.specify_args(self)
        self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
        self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
        self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
        self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')

    def main(self):
        """Copy the filtered rows from 'src' to 'dst' and log the row count.

        Raises:
            NotImplementedError: when a destination row already exists and
                ``--dst-exists=update`` was requested.
        """
        # Read list of IDs from file; they are substituted into --src-filter
        # as {ids}. Blank lines (e.g. a trailing empty line) are skipped,
        # otherwise they would yield "1,2," and break the generated SQL.
        ids = '<no IDs read>'
        if self.args.file_with_ids:
            with open(self.args.file_with_ids, 'r') as f:
                ids = ','.join(ln.rstrip() for ln in f if ln.strip())

        # Read source data. --src-filter is optional (argparse leaves it as
        # None); an absent or empty filter selects every row.
        src_filter = self.args.src_filter or ''
        condition = src_filter.format(ids=ids) or 'true'
        with self.pgm.cursor('src') as src_curs:
            src_curs.execute('SELECT * FROM {} WHERE {}'.format(self.args.table_name, condition))
            #TODO: ORDER BY id OFFSET 0 LIMIT 100
            data = src_curs.fetchall_dict()
            src_curs.connection.commit()

        with self.pgm.cursor('dst') as dst_curs:
            copied = 0
            for row in data:
                keys = ', '.join(row.keys())
                values_mask = ', '.join(['%s'] * len(row))
                query = 'INSERT INTO {} ({}) VALUES ({})'.format(self.args.table_name, keys, values_mask)
                # A per-row savepoint lets a single failed INSERT be undone
                # without aborting the rest of the transaction.
                dst_curs.execute('SAVEPOINT the_query;')
                try:
                    dst_curs.execute(query, list(row.values()))
                except IntegrityError:
                    if self.args.dst_exists == 'rollback':
                        dst_curs.connection.rollback()
                        # The whole transaction was discarded, so none of the
                        # rows inserted so far will be committed.
                        copied = 0
                        self.log.info('Destination record already exists, rolled back.')
                        break
                    elif self.args.dst_exists == 'ignore':
                        dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
                        # ROLLBACK TO keeps the savepoint alive; release it so
                        # same-named savepoints do not pile up per skipped row.
                        dst_curs.execute('RELEASE SAVEPOINT the_query;')
                    elif self.args.dst_exists == 'update':
                        raise NotImplementedError('--dst-exists=update is not implemented')
                else:
                    dst_curs.execute('RELEASE SAVEPOINT the_query;')
                    copied += 1
            dst_curs.connection.commit()

        self.log.info('Copied %s rows.', copied)
# Module-level alias for the tool class defined in this module.
# NOTE(review): presumably the pgtool runner looks up `cls` by name to
# instantiate the tool -- confirm against the loader before renaming.
cls = BatchCopyTool