--- a/pgtoolkit/pgdiff.py Fri Jan 25 17:06:54 2013 +0100
+++ b/pgtoolkit/pgdiff.py Fri Jan 25 17:44:49 2013 +0100
@@ -65,8 +65,7 @@
def _formatchanges(self):
res = []
- for x in self.changes:
- type, a, b = x
+ for type, a, b in self.changes:
if type == 'notnull':
type = ''
a = self._formatnotnull(a)
@@ -161,7 +160,7 @@
'*' : 'ALTER',
}
- def __init__(self, change, schema, table, column, columntype, columndefault, changes=None):
+ def __init__(self, change, schema, table, column, columntype, columndefault, columnnotnull, changes=None):
DiffBase.__init__(self)
self.level = 2
self.type = 'column'
@@ -171,31 +170,43 @@
self.column = column
self.columntype = columntype
self.columndefault = columndefault
+ self.columnnotnull = columnnotnull
self.name = column
self.changes = changes
def format_patch(self):
- if self.change == '*':
- type_statement = ' TYPE'
- else:
- type_statement = ''
- if self.columntype is not None:
- type_statement += ' ' + self.columntype;
- out = []
- out += ['ALTER TABLE %s.%s %s COLUMN %s%s;' % (
+ alter_table = 'ALTER TABLE %s.%s %s COLUMN %s' % (
self.schema,
self.table,
self.ALTER_COMMANDS[self.change],
self.name,
- type_statement
- )]
- if self.columndefault:
- out += ['ALTER TABLE %s.%s ALTER COLUMN %s SET DEFAULT %s;' % (
- self.schema,
- self.table,
- self.name,
- self.columndefault
- )]
+ )
+ out = []
+ if self.change == '-':
+ out.append('%s;' % alter_table);
+ if self.change == '+':
+ notnull = ''
+ if self.columnnotnull:
+ notnull = ' NOT NULL'
+ default = ''
+ if self.columndefault:
+ default = ' DEFAULT %s' % self.columndefault
+ out.append('%s %s%s%s;'
+ % (alter_table, self.columntype, notnull, default));
+ if self.change == '*':
+ for type, a, b in self.changes:
+ if type == 'type':
+ out.append('%s TYPE %s;' % (alter_table, b))
+ if type == 'notnull':
+ if a and not b:
+ out.append('%s DROP NOT NULL;' % alter_table)
+ if not a and b:
+ out.append('%s SET NOT NULL;' % alter_table)
+ if type == 'default':
+ if b:
+ out.append('%s SET DEFAULT %s;' % (alter_table, b))
+ else:
+ out.append('%s DROP DEFAULT;' % alter_table)
return out
@@ -300,11 +311,13 @@
if nd[1] in dst_columns:
dst_type = dst_columns[nd[1]].type
dst_default = dst_columns[nd[1]].default
+ dst_notnull = dst_columns[nd[1]].notnull
else:
dst_type = None
dst_default = None
+ dst_notnull = None
cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1],
- columntype=dst_type, columndefault=dst_default)
+ columntype=dst_type, columndefault=dst_default, columnnotnull=dst_notnull)
if nd[0] == '*':
a = src_columns[nd[1]]
b = dst_columns[nd[1]]