tablediff.py
changeset 34 98c7809af415
parent 14 a900bc629ecc
child 35 e7f79c4a27ce
--- a/tablediff.py	Wed Mar 28 17:25:18 2012 +0200
+++ b/tablediff.py	Thu May 10 08:42:21 2012 +0200
@@ -8,40 +8,48 @@
 #    Order is not important.
 #
 
-from pgtoolkit import pgmanager, pgbrowser, pgdatadiff, toolbase
+from pgtoolkit import toolbase, pgmanager, pgdatadiff
+from pgtoolkit.highlight import *
 
 
-class TableDiffTool(toolbase.SrcDstTool):
+class TableDiffTool(toolbase.SrcDstTablesTool):
     def __init__(self):
-        toolbase.SrcDstTool.__init__(self, name='tablediff', desc='Table diff.')
+        toolbase.SrcDstTablesTool.__init__(self, name='tablediff', desc='Table diff.')
         
-        self.parser.add_argument('srctable', metavar='srctable',
-            type=str, help='Source table name.')
-        self.parser.add_argument('--dst-table', dest='dsttable', metavar='dsttable',
-            type=str, default=None, help='Destination table (default=srctable).')
-        self.parser.add_argument('-s', '--src-schema', dest='srcschema', metavar='srcschema',
-            type=str, default='public', help='Schema name (default=public).')
-        self.parser.add_argument('--dst-schema', dest='dstschema', metavar='dstschema',
-            type=str, default=None, help='Destination schema name (default=srcschema).')
         self.parser.add_argument('--sql', action='store_true', help='Output is SQL script.')
+        self.parser.add_argument('--rowcount', action='store_true', help='Compare number of rows.')
         
         self.init()
 
     def main(self):
-        srcschema = self.args.srcschema
-        dstschema = self.args.dstschema if self.args.dstschema else self.args.srcschema
+        srcconn = self.pgm.get_conn('src')
+        dstconn = self.pgm.get_conn('dst')
         
-        srctable = self.args.srctable
-        dsttable = self.args.dsttable if self.args.dsttable else self.args.srctable
+        dd = pgdatadiff.PgDataDiff(srcconn, dstconn)
         
-        dd = pgdatadiff.PgDataDiff(self.pgm.get_conn('src'), self.pgm.get_conn('dst'))
-        dd.settable1(srctable, srcschema)
-        dd.settable2(dsttable, dstschema)
-        
-        if self.args.sql:
-            dd.print_patch()
-        else:
-            dd.print_diff()
+        for srcschema, srctable, dstschema, dsttable in self.tables():
+            print('Diff from [%s] %s.%s to [%s] %s.%s' % (
+                self.args.src, srcschema, srctable,
+                self.args.dst, dstschema, dsttable))
+            
+            if self.args.rowcount:
+                with self.pgm.cursor('src') as curs:
+                    curs.execute('''SELECT count(*) FROM "%s"."%s"''' % (srcschema, srctable))
+                    srccount = curs.fetchone()[0]
+                with self.pgm.cursor('dst') as curs:
+                    curs.execute('''SELECT count(*) FROM "%s"."%s"''' % (dstschema, dsttable))
+                    dstcount = curs.fetchone()[0]
+                if srccount != dstcount:
+                    print(highlight(1, BOLD | YELLOW), "Row count differs: src=%s dst=%s" % (srccount, dstcount), highlight(0), sep='')
+                continue
+            
+            dd.settable1(srctable, srcschema)
+            dd.settable2(dsttable, dstschema)
+            
+            if self.args.sql:
+                dd.print_patch()
+            else:
+                dd.print_diff()
 
 
 tool = TableDiffTool()