PgDiff: add partial support for SQL patch.
authorRadek Brich <radek.brich@devl.cz>
Wed, 26 Sep 2012 23:32:02 +0200 (2012-09-26)
changeset 47 bb8c729ae6ce
parent 46 037410ef2b6b
child 48 b82c7c2fb5af
PgDiff: add partial support for SQL patch.
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgdiff.py	Wed Sep 26 23:29:54 2012 +0200
+++ b/pgtoolkit/pgdiff.py	Wed Sep 26 23:32:02 2012 +0200
@@ -32,7 +32,14 @@
     COLORS = {
         '+' : BOLD | GREEN,
         '-' : BOLD | RED,
-        '*' : BOLD | YELLOW}
+        '*' : BOLD | YELLOW,
+    }
+    
+    COMMANDS = {
+        '+' : 'CREATE',
+        '-' : 'DROP',
+        '*' : 'ALTER',
+    }
     
     def __init__(self):
         self.changes = None
@@ -46,24 +53,24 @@
         out += [' ', self.type, ' ', self.name, highlight(0)]
         
         if self.changes:
-            out += [highlight(1, WHITE), ' (', self.formatchanges(), ')', highlight(0)]
+            out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)]
 
         return ''.join(out)
 
-    def formatnotnull(self, notnull):
+    def _formatnotnull(self, notnull):
         if notnull:
             return 'NOT NULL'
         else:
             return None
     
-    def formatchanges(self):
+    def _formatchanges(self):
         res = []
         for x in self.changes:
             type, a, b = x
             if type == 'notnull':
                 type = ''
-                a = self.formatnotnull(a)
-                b = self.formatnotnull(b)
+                a = self._formatnotnull(a)
+                b = self._formatnotnull(b)
                 
             if a and b:
                 s = ''.join(['Changed ', type, ' from ',
@@ -83,6 +90,11 @@
                 s = ''.join(l)
             res.append(s)
         return ' '.join(res)
+    
+    def format_patch(self):
+        if self.change == '*' and self.type in ('schema', 'table'):
+            return None
+        return '%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)
 
 
 class DiffSchema(DiffBase):
@@ -93,7 +105,7 @@
         self.change = change
         self.schema = schema
         self.name = schema
-        
+
 
 class DiffTable(DiffBase):
     def __init__(self, change, schema, table):
@@ -107,7 +119,13 @@
 
 
 class DiffColumn(DiffBase):
-    def __init__(self, change, schema, table, column, changes=None):
+    ALTER_COMMANDS = {
+        '+' : 'ADD',
+        '-' : 'DROP',
+        '*' : 'ALTER',
+    }
+    
+    def __init__(self, change, schema, table, column, columntype, columndefault, changes=None):
         DiffBase.__init__(self)
         self.level = 2
         self.type = 'column'
@@ -115,8 +133,23 @@
         self.schema = schema
         self.table = table
         self.column = column
+        self.columntype = columntype
+        self.columndefault = columndefault
         self.name = column
         self.changes = changes
+    
+    def format_patch(self):
+        out = 'ALTER TABLE %s.%s %s COLUMN %s %s' % (
+            self.schema,
+            self.table,
+            self.ALTER_COMMANDS[self.change],
+            self.name,
+            self.columntype
+        )
+        if self.columndefault:
+            out += ' DEFAULT ' + self.columndefault
+        out += ';'
+        return out
 
 
 class DiffConstraint(DiffBase):
@@ -142,22 +175,21 @@
         self.include_tables = set()
         self.exclude_tables = set()
     
-    def _test_filter(self, schema):
+    def _test_schema(self, schema):
         if self.include_schemas and schema not in self.include_schemas:
             return False
         if schema in self.exclude_schemas:
             return False
         return True
     
-    def _test_table(self, schema, table):
-        name = schema + '.' + table
-        if self.include_tables and name not in self.include_tables:
+    def _test_table(self, table):
+        if self.include_tables and table not in self.include_tables:
             return False
-        if name in self.exclude_tables:
+        if table in self.exclude_tables:
             return False
         return True 
     
-    def _diffnames(self, src, dst):
+    def _diff_names(self, src, dst):
         for x in src:
             if x in dst:
                 yield ('*', x)
@@ -186,8 +218,9 @@
         return diff
                 
     def _diff_columns(self, schema, table, src_columns, dst_columns):
-        for nd in self._diffnames(src_columns, dst_columns):
-            cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1])
+        for nd in self._diff_names(src_columns, dst_columns):
+            cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1],
+                columntype=dst_columns[nd[1]].type, columndefault=dst_columns[nd[1]].default)
             if nd[0] == '*':
                 a = src_columns[nd[1]]
                 b = dst_columns[nd[1]]
@@ -198,7 +231,7 @@
                 yield cdo
 
     def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
-        for nd in self._diffnames(src_constraints, dst_constraints):
+        for nd in self._diff_names(src_constraints, dst_constraints):
             cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1])
             if nd[0] == '*':
                 a = src_constraints[nd[1]]
@@ -210,8 +243,8 @@
                 yield cdo
 
     def _difftables(self, schema, src_tables, dst_tables):
-        for nd in self._diffnames(src_tables, dst_tables):
-            if not self._test_table(schema, nd[1]):
+        for nd in self._diff_names(src_tables, dst_tables):
+            if not self._test_table(nd[1]):
                 continue
             tdo = DiffTable(change=nd[0], schema=schema, table=nd[1])
             if nd[0] == '*':
@@ -236,7 +269,7 @@
 
     def iter_diff(self):
         '''Return diff between src and dst database schema.
-
+        
         Yields one line at the time. Each line is in form of object
         iherited from DiffBase. This object contains all information
         about changes. See format() method.
@@ -244,9 +277,9 @@
         '''
         src_schemas = self.src.schemas
         dst_schemas = self.dst.schemas
-        src = [x.name for x in src_schemas.values() if not x.system and self._test_filter(x.name)]
-        dst = [x.name for x in dst_schemas.values() if not x.system and self._test_filter(x.name)]
-        for nd in self._diffnames(src, dst):
+        src = [x.name for x in src_schemas.values() if not x.system and self._test_schema(x.name)]
+        dst = [x.name for x in dst_schemas.values() if not x.system and self._test_schema(x.name)]
+        for nd in self._diff_names(src, dst):
             sdo = DiffSchema(change=nd[0], schema=nd[1])
             if nd[0] == '*':
                 src_tables = src_schemas[nd[1]].tables
@@ -286,7 +319,9 @@
         
         '''
         for ln in self.iter_diff():
-            print(ln.format_patch())
+            patch = ln.format_patch()
+            if patch:
+                print(patch)
 
     def filter_schemas(self, include=[], exclude=[]):
         '''Modify list of schemas which are used for computing diff.