pgtoolkit/toolbase.py
changeset 101 2a2d0d5df03b
parent 100 d6088dba8fea
child 102 fda45bdfd68d
--- a/pgtoolkit/toolbase.py	Tue May 06 18:37:41 2014 +0200
+++ b/pgtoolkit/toolbase.py	Tue May 06 18:37:43 2014 +0200
@@ -26,24 +26,29 @@
 
 
 class ToolBase:
-    def __init__(self, name, desc, **kwargs):
-        self.parser = argparse.ArgumentParser(prog=name, description=desc,
+
+    def __init__(self, name, desc=None, **kwargs):
+        self.config = ConfigParser()
+        self.parser = argparse.ArgumentParser(prog=name, description=desc or self.__doc__,
             formatter_class=ToolDescriptionFormatter)
-        self.parser.add_argument('-d', dest='debug', action='store_true',
-            help='Debug mode - print database queries.')
+        self.pgm = pgmanager.get_instance()
+        self.target_isolation_level = None
 
-        self.config = ConfigParser()
+    def setup(self, args=None):
+        self.specify_args()
+        self.load_args(args)
+        self.init_logging()
+
+    def specify_args(self):
         self.config.add_option('databases', dict)
         self.config.add_option('meta_db')
         self.config.add_option('meta_query')
-
-        self.pgm = pgmanager.get_instance()
-        self.target_isolation_level = None
+        self.parser.add_argument('-Q', dest='queries', action='store_true',
+            help='Print database queries.')
 
-    def init(self, args=None):
-        self.config.load('pgtoolkit.conf')
+    def load_args(self, args=None, config_file=None):
+        self.config.load(config_file or 'pgtoolkit.conf')
         self.args = self.parser.parse_args(args)
-        self.init_logging()
 
     def init_logging(self):
         # logging
@@ -59,18 +64,20 @@
         log_notices.addHandler(handler)
         log_notices.setLevel(logging.DEBUG)
 
-        if self.args.debug:
+        if self.args.queries:
             log_sql = logging.getLogger('pgmanager_sql')
             log_sql.addHandler(handler)
             log_sql.setLevel(logging.DEBUG)
 
     def prepare_conn_from_metadb(self, name, lookup_name):
-        '''Create connection in pgmanager using meta DB.
+        """Create connection in pgmanager using meta DB.
 
         name -- Name for connection in pgmanager.
         lookup_name -- Name of connection in meta DB.
 
-        '''
+        """
+        if not self.pgm.knows_conn('meta'):
+            self.pgm.create_conn(name='meta', dsn=self.config.meta_db)
         with self.pgm.cursor('meta') as curs:
             curs.execute(self.config.meta_query, [lookup_name])
             row = curs.fetchone_dict()
@@ -80,9 +87,10 @@
                     isolation_level=self.target_isolation_level,
                     **row)
                 return True
+        self.pgm.close_conn('meta')
 
     def prepare_conn_from_config(self, name, lookup_name):
-        '''Create connection in pgmanager using info in config.databases.'''
+        """Create connection in pgmanager using info in config.databases."""
         if self.config.databases:
             if lookup_name in self.config.databases:
                 dsn = self.config.databases[lookup_name]
@@ -99,9 +107,6 @@
             value: connection name in config or meta DB
 
         """
-        if self.config.meta_db:
-            self.pgm.create_conn(name='meta', dsn=self.config.meta_db)
-
         for name in kwargs:
             lookup_name = kwargs[name]
             found = self.prepare_conn_from_config(name, lookup_name)
@@ -110,41 +115,52 @@
             if not found:
                 raise ConnectionInfoNotFound('Connection name "%s" not found in config nor in meta DB.' % lookup_name)
 
-        if self.config.meta_db:
-            self.pgm.close_conn('meta')
-
 
 class SimpleTool(ToolBase):
-    def __init__(self, name, desc, **kwargs):
+
+    def __init__(self, name, desc=None, **kwargs):
         ToolBase.__init__(self, name, desc, **kwargs)
+
+    def specify_args(self):
+        ToolBase.specify_args(self)
         self.parser.add_argument('target', metavar='target', type=str, help='Target database')
 
-    def init(self, args=None):
-        ToolBase.init(self, args)
+    def setup(self, args=None):
+        ToolBase.setup(self, args)
         self.prepare_conns(target=self.args.target)
 
 
 class SrcDstTool(ToolBase):
-    def __init__(self, name, desc, **kwargs):
+
+    def __init__(self, name, desc=None, *, allow_reverse=False, force_reverse=False, **kwargs):
         ToolBase.__init__(self, name, desc, **kwargs)
+        self.allow_reverse = allow_reverse
+        self.force_reverse = force_reverse
+
+    def specify_args(self):
+        ToolBase.specify_args(self)
         self.parser.add_argument('src', metavar='source', type=str, help='Source database')
         self.parser.add_argument('dst', metavar='destination', type=str, help='Destination database')
-        if 'allow_reverse' in kwargs and kwargs['allow_reverse']:
+        if self.allow_reverse:
             self.parser.add_argument('-r', '--reverse', action='store_true', help='Reverse operation. Swap source and destination.')
 
-    def init(self, args=None):
-        ToolBase.init(self, args)
+    def load_args(self, args=None, config_file=None):
+        ToolBase.load_args(self, args, config_file)
         if self.is_reversed():
             self.args.src, self.args.dst = self.args.dst, self.args.src
+
+    def setup(self, args=None):
+        ToolBase.setup(self, args)
         self.prepare_conns(src=self.args.src, dst=self.args.dst)
 
     def is_reversed(self):
-        return 'reverse' in self.args and self.args.reverse
+        return ('reverse' in self.args and self.args.reverse) or self.force_reverse
 
 
 class SrcDstTablesTool(SrcDstTool):
-    def __init__(self, name, desc, **kwargs):
-        SrcDstTool.__init__(self, name, desc, **kwargs)
+
+    def specify_args(self):
+        SrcDstTool.specify_args(self)
         self.parser.add_argument('-t', '--src-table', metavar='source_table',
             dest='srctable', type=str, default='', help='Source table name.')
         self.parser.add_argument('-s', '--src-schema', metavar='source_schema',
@@ -155,9 +171,11 @@
             dest='dstschema', type=str, default='', help='Destination schema name (default=source_schema).')
         self.parser.add_argument('--regex', action='store_true', help="Use RE in schema or table name.")
 
-    def init(self, args=None):
-        SrcDstTool.init(self, args)
+    def load_args(self, args=None, config_file=None):
+        SrcDstTool.load_args(self, args, config_file)
+        self.load_table_names()
 
+    def load_table_names(self):
         self.schema1 = self.args.srcschema
         self.table1 = self.args.srctable
         self.schema2 = self.args.dstschema