Add reverse parameter for diff tools.
authorRadek Brich <radek.brich@devl.cz>
Mon, 17 Dec 2012 16:48:12 +0100
changeset 56 94e091c23ebb
parent 55 adc1615d8fc5
child 57 ba323bbed6a4
Add reverse parameter for diff tools.
pgtoolkit/toolbase.py
schemadiff.py
tablecopy.py
tablediff.py
--- a/pgtoolkit/toolbase.py	Thu Dec 13 17:15:10 2012 +0100
+++ b/pgtoolkit/toolbase.py	Mon Dec 17 16:48:12 2012 +0100
@@ -16,7 +16,7 @@
 
 
 class ToolBase:
-    def __init__(self, name, desc):
+    def __init__(self, name, desc, **kwargs):
         self.parser = argparse.ArgumentParser(description=desc)
         self.parser.add_argument('-d', dest='debug', action='store_true',
             help='Debug mode - print database queries.')
@@ -97,8 +97,8 @@
 
 
 class SimpleTool(ToolBase):
-    def __init__(self, name, desc):
-        ToolBase.__init__(self, name, desc)
+    def __init__(self, name, desc, **kwargs):
+        ToolBase.__init__(self, name, desc, **kwargs)
         self.parser.add_argument('target', metavar='target', type=str, help='Target database')
 
     def init(self):
@@ -107,19 +107,26 @@
 
 
 class SrcDstTool(ToolBase):
-    def __init__(self, name, desc):
-        ToolBase.__init__(self, name, desc)
+    def __init__(self, name, desc, **kwargs):
+        ToolBase.__init__(self, name, desc, **kwargs)
         self.parser.add_argument('src', metavar='source', type=str, help='Source database')
         self.parser.add_argument('dst', metavar='destination', type=str, help='Destination database')
+        if 'allow_reverse' in kwargs and kwargs['allow_reverse']:
+            self.parser.add_argument('-r', '--reverse', action='store_true', help='Reverse operation. Swap source and destination.')
 
     def init(self):
         ToolBase.init(self)
+        if self.is_reversed():
+            self.args.src, self.args.dst = self.args.dst, self.args.src
         self.prepare_conns_from_cmdline_args('src', 'dst')
 
+    def is_reversed(self):
+        return 'reverse' in self.args and self.args.reverse
+
 
 class SrcDstTablesTool(SrcDstTool):
-    def __init__(self, name, desc):
-        SrcDstTool.__init__(self, name, desc)
+    def __init__(self, name, desc, **kwargs):
+        SrcDstTool.__init__(self, name, desc, **kwargs)
         self.parser.add_argument('-t', '--src-table', metavar='source_table',
             dest='srctable', type=str, default='', help='Source table name.')
         self.parser.add_argument('-s', '--src-schema', metavar='source_schema',
@@ -152,6 +159,11 @@
         if not self.table2:
             self.table2 = self.table1
 
+        # swap src, dst when in reverse mode
+        if self.is_reversed():
+            self.schema1, self.schema2 = self.schema2, self.schema1
+            self.table1, self.table2 = self.table2, self.table1
+
     def tables(self):
         '''Generator. Yields schema1, table1, schema2, table2.'''
         srcconn = self.pgm.get_conn('src')
--- a/schemadiff.py	Thu Dec 13 17:15:10 2012 +0100
+++ b/schemadiff.py	Mon Dec 17 16:48:12 2012 +0100
@@ -11,23 +11,23 @@
 
 class SchemaDiffTool(toolbase.SrcDstTool):
     def __init__(self):
-        toolbase.SrcDstTool.__init__(self, name='schemadiff', desc='Database schema diff.')
-        
+        toolbase.SrcDstTool.__init__(self, name='schemadiff', desc='Database schema diff.', allow_reverse = True)
+
         self.parser.add_argument('-s', dest='schema', nargs='*', help='Schema filter')
         self.parser.add_argument('-t', dest='table', nargs='*', help='Table filter')
         self.parser.add_argument('--sql', action='store_true', help='Output is SQL script.')
-        
+
         self.init()
 
     def main(self):
         srcbrowser = pgbrowser.PgBrowser(self.pgm.get_conn('src'))
         dstbrowser = pgbrowser.PgBrowser(self.pgm.get_conn('dst'))
-        
+
         pgd = pgdiff.PgDiff(srcbrowser, dstbrowser)
 
         if self.args.schema:
             pgd.filter_schemas(include=self.args.schema)
-        
+
         if self.args.table:
             pgd.filter_tables(include=self.args.table)
 
--- a/tablecopy.py	Thu Dec 13 17:15:10 2012 +0100
+++ b/tablecopy.py	Mon Dec 17 16:48:12 2012 +0100
@@ -2,6 +2,7 @@
 #
 # Copy data between tables with same table schema.
 #
+# Copies full table, target table must be empty.
 # Can copy multiple tables in one run.
 # Sorts the tables according to references.
 #
@@ -17,20 +18,20 @@
 class TableCopyTool(toolbase.SrcDstTablesTool):
     def __init__(self):
         toolbase.SrcDstTablesTool.__init__(self, name='tablecopy', desc='Table copy tool.')
-        
+
         self.parser.add_argument('-n', '--no-action', dest='noaction', action='store_true',
             help="Do nothing, just print tables to be copied. Useful in combination with --regex.")
         self.parser.add_argument('--no-sort', dest='nosort', action='store_true',
             help="Do not sort. By default, tables are sorted by foreign key references.")
-        
+
         self.init()
 
     def main(self):
         self.srcconn = self.pgm.get_conn('src')
         self.dstconn = self.pgm.get_conn('dst')
-        
+
         dc = pgdatacopy.PgDataCopy(self.srcconn, self.dstconn)
-        
+
         if self.args.nosort:
             for table in self.tables():
                 self.copy_table(dc, *table)
@@ -39,7 +40,7 @@
             details = dict()
             pending = set()
             references = dict()
-            
+
             # build list of all table to be copied (pending) and references map
             for table in self.tables():
                 srcschema, srctable, dstschema, dsttable = table
@@ -47,7 +48,7 @@
                 details[name] = table
                 pending.add(name)
                 references[name] = self.get_references(dstschema, dsttable)
-            
+
             # copy files with fulfilled references, repeat until all done
             while pending:
                 for name in list(pending):
@@ -69,32 +70,32 @@
         print('Copying [%s] %s.%s --> [%s] %s.%s' % (
             self.args.src, srcschema, srctable,
             self.args.dst, dstschema, dsttable))
-        
+
         if self.args.noaction:
             return
-        
+
         dc.set_source(srctable, srcschema)
         dc.set_destination(dsttable, dstschema)
-        
+
         try:
             dc.check()
         except pgdatacopy.TargetNotEmptyError as e:
             print(' - error:', str(e))
             return
-        
+
         print(' - read                           ')
         buf = io.BytesIO()
         wrapped = ProgressWrapper(buf)
         dc.read(wrapped)
         data = buf.getvalue()
         buf.close()
-        
+
         print(' - write                          ')
         buf = io.BytesIO(data)
         wrapped = ProgressWrapper(buf, len(data))
         dc.write(wrapped)
         buf.close()
-        
+
         print(' - analyze                        ')
         dc.analyze()
 
--- a/tablediff.py	Thu Dec 13 17:15:10 2012 +0100
+++ b/tablediff.py	Mon Dec 17 16:48:12 2012 +0100
@@ -14,24 +14,24 @@
 
 class TableDiffTool(toolbase.SrcDstTablesTool):
     def __init__(self):
-        toolbase.SrcDstTablesTool.__init__(self, name='tablediff', desc='Table diff.')
-        
+        toolbase.SrcDstTablesTool.__init__(self, name='tablediff', desc='Table diff.', allow_reverse = True)
+
         self.parser.add_argument('--sql', action='store_true', help='Output is SQL script.')
         self.parser.add_argument('--rowcount', action='store_true', help='Compare number of rows.')
-        
+
         self.init()
 
     def main(self):
         srcconn = self.pgm.get_conn('src')
         dstconn = self.pgm.get_conn('dst')
-        
+
         dd = pgdatadiff.PgDataDiff(srcconn, dstconn)
-        
+
         for srcschema, srctable, dstschema, dsttable in self.tables():
             print('-- Diff from [%s] %s.%s to [%s] %s.%s' % (
                 self.args.src, srcschema, srctable,
                 self.args.dst, dstschema, dsttable))
-            
+
             if self.args.rowcount:
                 with self.pgm.cursor('src') as curs:
                     curs.execute('''SELECT count(*) FROM "%s"."%s"''' % (srcschema, srctable))
@@ -42,10 +42,10 @@
                 if srccount != dstcount:
                     print(highlight(1, BOLD | YELLOW), "Row count differs: src=%s dst=%s" % (srccount, dstcount), highlight(0), sep='')
                 continue
-            
+
             dd.settable1(srctable, srcschema)
             dd.settable2(dsttable, dstschema)
-            
+
             if self.args.sql:
                 dd.print_patch()
             else: