Update PgDiff: Support SQL patch for constraints. Fix changes of column default value.
authorRadek Brich <radek.brich@devl.cz>
Tue, 11 Dec 2012 11:25:06 +0100
changeset 53 4a049a5af657
parent 52 26121a8fe78b
child 54 291473ab847c
Update PgDiff: Support SQL patch for constraints. Fix changes of column default value.
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgdiff.py	Tue Dec 11 10:49:42 2012 +0100
+++ b/pgtoolkit/pgdiff.py	Tue Dec 11 11:25:06 2012 +0100
@@ -34,13 +34,13 @@
         '-' : BOLD | RED,
         '*' : BOLD | YELLOW,
     }
-    
+
     COMMANDS = {
         '+' : 'CREATE',
         '-' : 'DROP',
         '*' : 'ALTER',
     }
-    
+
     def __init__(self):
         self.changes = None
 
@@ -49,9 +49,9 @@
 
         out.append(highlight(1, self.COLORS[self.change]))
         out.append(self.change)
-        
+
         out += [' ', self.type, ' ', self.name, highlight(0)]
-        
+
         if self.changes:
             out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)]
 
@@ -62,7 +62,7 @@
             return 'NOT NULL'
         else:
             return None
-    
+
     def _formatchanges(self):
         res = []
         for x in self.changes:
@@ -71,7 +71,7 @@
                 type = ''
                 a = self._formatnotnull(a)
                 b = self._formatnotnull(b)
-                
+
             if a and b:
                 s = ''.join(['Changed ', type, ' from ',
                     highlight(1,15), a, highlight(0), ' to ',
@@ -90,11 +90,11 @@
                 s = ''.join(l)
             res.append(s)
         return ' '.join(res)
-    
+
     def format_patch(self):
         if self.change == '*' and self.type in ('schema', 'table'):
             return None
-        return '%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)
+        return ['%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)]
 
 
 class DiffSchema(DiffBase):
@@ -124,7 +124,7 @@
         '-' : 'DROP',
         '*' : 'ALTER',
     }
-    
+
     def __init__(self, change, schema, table, column, columntype, columndefault, changes=None):
         DiffBase.__init__(self)
         self.level = 2
@@ -137,23 +137,34 @@
         self.columndefault = columndefault
         self.name = column
         self.changes = changes
-    
+
     def format_patch(self):
-        out = 'ALTER TABLE %s.%s %s COLUMN %s %s' % (
+        if self.change == '*':
+            type_statement = ' TYPE'
+        else:
+            type_statement = ''
+        if self.columntype is not None:
+            type_statement += ' ' + self.columntype;
+        out = []
+        out += ['ALTER TABLE %s.%s %s COLUMN %s%s;' % (
             self.schema,
             self.table,
             self.ALTER_COMMANDS[self.change],
             self.name,
-            self.columntype
-        )
+            type_statement
+        )]
         if self.columndefault:
-            out += ' DEFAULT ' + self.columndefault
-        out += ';'
+            out += ['ALTER TABLE %s.%s ALTER COLUMN %s SET DEFAULT %s;' % (
+                self.schema,
+                self.table,
+                self.name,
+                self.columndefault
+                )]
         return out
 
 
 class DiffConstraint(DiffBase):
-    def __init__(self, change, schema, table, constraint, changes=None):
+    def __init__(self, change, schema, table, constraint, definition, changes=None):
         DiffBase.__init__(self)
         self.level = 2
         self.type = 'constraint'
@@ -162,8 +173,21 @@
         self.table = table
         self.constraint = constraint
         self.name = constraint
+        self.definition = definition
         self.changes = changes
 
+    def format_patch(self):
+        q_alter = 'ALTER TABLE %s.%s' % (self.schema, self.table)
+        q_drop = '%s DROP CONSTRAINT %s;' % (q_alter, self.constraint)
+        q_add = '%s ADD CONSTRAINT %s %s;' % (q_alter, self.constraint, self.definition)
+        if self.change == '*':
+            out = [q_drop, q_add]
+        if self.change == '+':
+            out = [q_add]
+        if self.change == '-':
+            out = [q_drop]
+        return out
+
 
 class PgDiff:
     def __init__(self, srcbrowser=None, dstbrowser=None):
@@ -174,21 +198,21 @@
         self.exclude_schemas = set()  # exclude these schemas from diff
         self.include_tables = set()
         self.exclude_tables = set()
-    
+
     def _test_schema(self, schema):
         if self.include_schemas and schema not in self.include_schemas:
             return False
         if schema in self.exclude_schemas:
             return False
         return True
-    
+
     def _test_table(self, table):
         if self.include_tables and table not in self.include_tables:
             return False
         if table in self.exclude_tables:
             return False
-        return True 
-    
+        return True
+
     def _diff_names(self, src, dst):
         for x in src:
             if x in dst:
@@ -208,7 +232,7 @@
         if a.default != b.default:
             diff.append(('default', a.default, b.default))
         return diff
-    
+
     def _compare_constraints(self, a, b):
         diff = []
         if a.type != b.type:
@@ -216,11 +240,17 @@
         if a.definition != b.definition:
             diff.append(('definition', a.definition, b.definition))
         return diff
-                
+
     def _diff_columns(self, schema, table, src_columns, dst_columns):
         for nd in self._diff_names(src_columns, dst_columns):
+            if nd[1] in dst_columns:
+                dst_type = dst_columns[nd[1]].type
+                dst_default = dst_columns[nd[1]].default
+            else:
+                dst_type = None
+                dst_default = None
             cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1],
-                columntype=dst_columns[nd[1]].type, columndefault=dst_columns[nd[1]].default)
+                columntype=dst_type, columndefault=dst_default)
             if nd[0] == '*':
                 a = src_columns[nd[1]]
                 b = dst_columns[nd[1]]
@@ -232,7 +262,12 @@
 
     def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
         for nd in self._diff_names(src_constraints, dst_constraints):
-            cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1])
+            if nd[1] in dst_constraints:
+                dst_definition = dst_constraints[nd[1]].definition
+            else:
+                dst_definition = None
+            cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1],
+                definition=dst_definition)
             if nd[0] == '*':
                 a = src_constraints[nd[1]]
                 b = dst_constraints[nd[1]]
@@ -269,11 +304,11 @@
 
     def iter_diff(self):
         '''Return diff between src and dst database schema.
-        
+
         Yields one line at the time. Each line is in form of object
         iherited from DiffBase. This object contains all information
         about changes. See format() method.
-        
+
         '''
         src_schemas = self.src.schemas
         dst_schemas = self.dst.schemas
@@ -294,41 +329,41 @@
 
     def print_diff(self):
         '''Print diff between src and dst database schema.
-        
+
         The output is in human readable form.
-        
-        Set allowcolor=True of PgDiff instance to get colored output. 
-        
+
+        Set allowcolor=True of PgDiff instance to get colored output.
+
         '''
         for ln in self.iter_diff():
             print(ln.format())
-        
+
     def print_patch(self):
         '''Print patch for updating from src schema to dst schema.
-        
+
         Supports table drop, add, column drop, add and following
         changes of columns:
           - type
           - set/remove not null
           - default value
-        
+
         This is experimental, not tested very much.
         Do not use without checking the commands.
         Even if it works as intended, it can cause table lock ups
         and/or loss of data. You have been warned.
-        
+
         '''
         for ln in self.iter_diff():
             patch = ln.format_patch()
             if patch:
-                print(patch)
+                print('\n'.join(patch))
 
     def filter_schemas(self, include=[], exclude=[]):
         '''Modify list of schemas which are used for computing diff.
-        
+
         include (list) -- if not empty, consider only these schemas for diff
         exclude (list) -- exclude these schemas from diff
-        
+
         Order: include, exclude
         include=[] means include everything
         '''