batchcopy.py
author Radek Brich <brich.radek@ifortuna.cz>
Tue, 06 May 2014 18:34:38 +0200
changeset 99 245646538743
parent 98 024299702087
permissions -rwxr-xr-x
Update runquery tool: Add --one-query-per-line parameter.

#!/usr/bin/env python3

from pgtoolkit import toolbase
from pgtoolkit.pgmanager import IntegrityError


class BatchCopyTool(toolbase.SrcDstTablesTool):
    """Copy rows of one table from the ``src`` database to the ``dst`` database.

    Rows are selected with ``SELECT * FROM <table> WHERE <filter>`` on the
    ``src`` connection and inserted one by one into the same-named table on
    the ``dst`` connection. Each INSERT runs inside a savepoint so that a
    conflicting row can be skipped (``--dst-exists ignore``) without aborting
    the whole destination transaction.
    """

    def __init__(self):
        toolbase.SrcDstTablesTool.__init__(self, name='batchcopy', desc='Copy data from one table to another.')

        self.parser.add_argument('--table-name', type=str, help='Table to be copied.')
        self.parser.add_argument('--src-filter', type=str, help='WHERE condition for source query.')
        self.parser.add_argument('--file-with-ids', type=str, help='Read source IDs from file (each ID on new line). Use these in --src-filter as {ids}')
        self.parser.add_argument('--dst-exists', choices=['rollback', 'ignore', 'update'], default='rollback', help='What to do when destination record already exists.')

        self.init()

    def main(self):
        """Copy the selected rows and log how many were committed."""
        ids = self._read_ids(self.args.file_with_ids)
        condition = self._build_condition(self.args.src_filter, ids)
        data = self._fetch_source_rows(self.args.table_name, condition)
        copied = self._insert_rows(self.args.table_name, data)
        self.log.info('Copied %s rows.', copied)

    @staticmethod
    def _read_ids(path):
        """Return the IDs listed in *path* (one per line) joined with commas.

        Surrounding whitespace is stripped and blank lines are skipped, so a
        trailing empty line does not produce ``1,2,`` (invalid SQL).
        When *path* is not given, a placeholder is returned instead. It is
        presumably chosen so that a filter using ``{ids}`` without
        ``--file-with-ids`` fails visibly in SQL.
        """
        if not path:
            return '<no IDs read>'
        with open(path, 'r') as f:
            return ','.join(line.strip() for line in f if line.strip())

    @staticmethod
    def _build_condition(src_filter, ids):
        """Return the WHERE condition: *src_filter* with ``{ids}`` substituted.

        Falls back to ``'true'`` (select every row) when no filter is given.
        The original code called ``.format`` on None and crashed in that case.
        """
        if not src_filter:
            return 'true'
        return src_filter.format(ids=ids)

    def _fetch_source_rows(self, table_name, condition):
        """Fetch all matching rows from the ``src`` connection as dicts."""
        with self.pgm.cursor('src') as src_curs:
            src_curs.execute('SELECT * FROM {} WHERE {}'.format(table_name, condition))
            # TODO: batch the copy with ORDER BY id OFFSET 0 LIMIT 100
            data = src_curs.fetchall_dict()
            # End the transaction the SELECT ran in before moving on to dst.
            src_curs.connection.commit()
        return data

    def _insert_rows(self, table_name, rows):
        """Insert *rows* into *table_name* on ``dst``; return the number committed.

        On IntegrityError, behavior depends on ``--dst-exists``:
        ``rollback`` undoes the whole transaction, stops, and reports 0 rows;
        ``ignore`` rolls back to the per-row savepoint and continues;
        ``update`` raises NotImplementedError.
        """
        copied = 0
        with self.pgm.cursor('dst') as dst_curs:
            for row in rows:
                keys = ', '.join(row.keys())
                values_mask = ', '.join(['%s'] * len(row))
                query = 'INSERT INTO {} ({}) VALUES ({})'.format(table_name, keys, values_mask)
                try:
                    # The savepoint lets 'ignore' undo just this row without
                    # aborting the rows inserted before it.
                    dst_curs.execute('SAVEPOINT the_query;')
                    dst_curs.execute(query, list(row.values()))
                    dst_curs.execute('RELEASE SAVEPOINT the_query;')
                    copied += 1
                except IntegrityError:
                    if self.args.dst_exists == 'rollback':
                        dst_curs.connection.rollback()
                        # Every row inserted so far was discarded, so do not
                        # report them as copied.
                        self.log.warning('Destination row already exists, rolled back %s inserted rows.', copied)
                        copied = 0
                        break
                    elif self.args.dst_exists == 'ignore':
                        dst_curs.execute('ROLLBACK TO SAVEPOINT the_query;')
                    elif self.args.dst_exists == 'update':
                        raise NotImplementedError()
            dst_curs.connection.commit()
        return copied


# Script entry point: construct the tool (its __init__ registers the CLI
# options and calls self.init()) and run the copy right away.
# NOTE(review): there is no `if __name__ == '__main__':` guard, so importing
# this module also runs the copy; add one only if nothing relies on
# import-time execution.
tool = BatchCopyTool()
tool.main()