from pgtoolkit import toolbase, pgmanager, pgdatadiff
from pgtoolkit.toolbase import SrcDstTablesTool
from pycolib.ansicolor import highlight, BOLD, YELLOW

import sys


class TableDiffTool(SrcDstTablesTool):

    """
    Print differences between data in tables.

    Requirements:
    * Source table must have defined PRIMARY KEY.
    * Destination table must contain all columns from source table.
      Order is not important.

    """

    # NOTE: the class docstring above doubles as the command-line description
    # (it is passed as ``desc`` in __init__), so edits to it are user-visible.

    def __init__(self):
        SrcDstTablesTool.__init__(self, name='tablediff', desc=self.__doc__, allow_reverse=True)

    def specify_args(self):
        """Register tablediff options on top of the base src/dst table arguments."""
        SrcDstTablesTool.specify_args(self)
        self.parser.add_argument('--sql', action='store_true', help='Output is SQL script.')
        self.parser.add_argument('--rowcount', action='store_true', help='Compare number of rows.')
        self.parser.add_argument('-o', '--output-file', help='Output file for sql queries.')

    def main(self):
        """Compare every selected src/dst table pair and write the report.

        Output goes to the file named by ``--output-file`` when given,
        otherwise to ``sys.stdout``. A file opened here is always closed
        before returning, even when the comparison raises.
        """
        srcconn = self.pgm.get_conn('src')
        dstconn = self.pgm.get_conn('dst')

        if self.args.output_file:
            output_file = open(self.args.output_file, 'w')
        else:
            output_file = sys.stdout

        try:
            self._diff_all_tables(srcconn, dstconn, output_file)
        finally:
            # Close only what we opened; never close the process's stdout.
            if output_file is not sys.stdout:
                output_file.close()

    def _diff_all_tables(self, srcconn, dstconn, output_file):
        """Write a header and then a row-count check or a data diff for each table pair.

        :param srcconn: connection returned by ``self.pgm.get_conn('src')``
        :param dstconn: connection returned by ``self.pgm.get_conn('dst')``
        :param output_file: writable text stream receiving all output
        """
        dd = pgdatadiff.PgDataDiff(srcconn, dstconn)

        for srcschema, srctable, dstschema, dsttable in self.tables():
            print('-- Diff from [%s] %s.%s to [%s] %s.%s' % (
                self.args.src, srcschema, srctable,
                self.args.dst, dstschema, dsttable),
                file=output_file)

            if self.args.rowcount:
                # --rowcount replaces the data diff for this pair entirely.
                self._print_rowcount_diff(srcschema, srctable, dstschema, dsttable, output_file)
                continue

            dd.settable1(srctable, srcschema)
            dd.settable2(dsttable, dstschema)

            if self.args.sql:
                dd.print_patch(file=output_file)
            else:
                dd.print_diff(file=output_file)

    def _print_rowcount_diff(self, srcschema, srctable, dstschema, dsttable, output_file):
        """Print a highlighted line to *output_file* when the two row counts differ.

        Nothing is printed when the counts are equal.
        """
        srccount = self._count_rows('src', srcschema, srctable)
        dstcount = self._count_rows('dst', dstschema, dsttable)
        if srccount != dstcount:
            # NOTE(review): highlight() output is written even when output_file
            # is a regular file — confirm whether escape codes are wanted there.
            print(highlight(1, BOLD | YELLOW),
                "Row count differs: src=%s dst=%s" % (srccount, dstcount),
                highlight(0), sep='', file=output_file)

    def _count_rows(self, conn_name, schema, table):
        """Return ``count(*)`` of *schema*.*table* using the connection named *conn_name*."""
        query = 'SELECT count(*) FROM %s.%s' % (
            self._quote_ident(schema), self._quote_ident(table))
        with self.pgm.cursor(conn_name) as curs:
            curs.execute(query)
            return curs.fetchone()[0]

    @staticmethod
    def _quote_ident(name):
        """Return *name* wrapped as a double-quoted PostgreSQL identifier.

        Embedded double quotes are doubled, so a name containing ``"`` can
        neither break the statement nor escape the quoted identifier.
        Presumably *name* is always a str coming from ``self.tables()``.
        """
        return '"%s"' % name.replace('"', '""')


# Module-level alias for the tool class defined above.
# NOTE(review): presumably the pgtoolkit tool runner imports this module and
# instantiates whatever is bound to ``cls`` — confirm against the loader.
cls = TableDiffTool

