PgDiff: Update patch for table column changed.
authorRadek Brich <radek.brich@devl.cz>
Fri, 25 Jan 2013 17:44:49 +0100
changeset 60 bb6b20106ff5
parent 59 65efd0c6919f
child 61 703bba757605
PgDiff: Update patch for table column changed.
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgdiff.py	Fri Jan 25 17:06:54 2013 +0100
+++ b/pgtoolkit/pgdiff.py	Fri Jan 25 17:44:49 2013 +0100
@@ -65,8 +65,7 @@
 
     def _formatchanges(self):
         res = []
-        for x in self.changes:
-            type, a, b = x
+        for type, a, b in self.changes:
             if type == 'notnull':
                 type = ''
                 a = self._formatnotnull(a)
@@ -161,7 +160,7 @@
         '*' : 'ALTER',
     }
 
-    def __init__(self, change, schema, table, column, columntype, columndefault, changes=None):
+    def __init__(self, change, schema, table, column, columntype, columndefault, columnnotnull, changes=None):
         DiffBase.__init__(self)
         self.level = 2
         self.type = 'column'
@@ -171,31 +170,43 @@
         self.column = column
         self.columntype = columntype
         self.columndefault = columndefault
+        self.columnnotnull = columnnotnull
         self.name = column
         self.changes = changes
 
     def format_patch(self):
-        if self.change == '*':
-            type_statement = ' TYPE'
-        else:
-            type_statement = ''
-        if self.columntype is not None:
-            type_statement += ' ' + self.columntype;
-        out = []
-        out += ['ALTER TABLE %s.%s %s COLUMN %s%s;' % (
+        alter_table = 'ALTER TABLE %s.%s %s COLUMN %s' % (
             self.schema,
             self.table,
             self.ALTER_COMMANDS[self.change],
             self.name,
-            type_statement
-        )]
-        if self.columndefault:
-            out += ['ALTER TABLE %s.%s ALTER COLUMN %s SET DEFAULT %s;' % (
-                self.schema,
-                self.table,
-                self.name,
-                self.columndefault
-                )]
+        )
+        out = []
+        if self.change == '-':
+            out.append('%s;' % alter_table);
+        if self.change == '+':
+            notnull = ''
+            if self.columnnotnull:
+                notnull = ' NOT NULL'
+            default = ''
+            if self.columndefault:
+                default = ' DEFAULT %s' % self.columndefault
+            out.append('%s %s%s%s;'
+                % (alter_table, self.columntype, notnull, default));
+        if self.change == '*':
+            for type, a, b in self.changes:
+                if type == 'type':
+                    out.append('%s TYPE %s;' % (alter_table, b))
+                if type == 'notnull':
+                    if a and not b:
+                        out.append('%s DROP NOT NULL;' % alter_table)
+                    if not a and b:
+                        out.append('%s SET NOT NULL;' % alter_table)
+                if type == 'default':
+                    if b:
+                        out.append('%s SET DEFAULT %s;' % (alter_table, b))
+                    else:
+                        out.append('%s DROP DEFAULT;' % alter_table)
         return out
 
 
@@ -300,11 +311,13 @@
             if nd[1] in dst_columns:
                 dst_type = dst_columns[nd[1]].type
                 dst_default = dst_columns[nd[1]].default
+                dst_notnull = dst_columns[nd[1]].notnull
             else:
                 dst_type = None
                 dst_default = None
+                dst_notnull = None
             cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1],
-                columntype=dst_type, columndefault=dst_default)
+                columntype=dst_type, columndefault=dst_default, columnnotnull=dst_notnull)
             if nd[0] == '*':
                 a = src_columns[nd[1]]
                 b = dst_columns[nd[1]]