--- a/pgtoolkit/pgdiff.py Wed Sep 26 23:29:54 2012 +0200
+++ b/pgtoolkit/pgdiff.py Wed Sep 26 23:32:02 2012 +0200
@@ -32,7 +32,14 @@
COLORS = {
'+' : BOLD | GREEN,
'-' : BOLD | RED,
- '*' : BOLD | YELLOW}
+ '*' : BOLD | YELLOW,
+ }
+
+ COMMANDS = {
+ '+' : 'CREATE',
+ '-' : 'DROP',
+ '*' : 'ALTER',
+ }
def __init__(self):
self.changes = None
@@ -46,24 +53,24 @@
out += [' ', self.type, ' ', self.name, highlight(0)]
if self.changes:
- out += [highlight(1, WHITE), ' (', self.formatchanges(), ')', highlight(0)]
+ out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)]
return ''.join(out)
- def formatnotnull(self, notnull):
+ def _formatnotnull(self, notnull):
if notnull:
return 'NOT NULL'
else:
return None
- def formatchanges(self):
+ def _formatchanges(self):
res = []
for x in self.changes:
type, a, b = x
if type == 'notnull':
type = ''
- a = self.formatnotnull(a)
- b = self.formatnotnull(b)
+ a = self._formatnotnull(a)
+ b = self._formatnotnull(b)
if a and b:
s = ''.join(['Changed ', type, ' from ',
@@ -83,6 +90,11 @@
s = ''.join(l)
res.append(s)
return ' '.join(res)
+
+ def format_patch(self):
+ if self.change == '*' and self.type in ('schema', 'table'):
+ return None
+ return '%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)
class DiffSchema(DiffBase):
@@ -93,7 +105,7 @@
self.change = change
self.schema = schema
self.name = schema
-
+
class DiffTable(DiffBase):
def __init__(self, change, schema, table):
@@ -107,7 +119,13 @@
class DiffColumn(DiffBase):
- def __init__(self, change, schema, table, column, changes=None):
+ ALTER_COMMANDS = {
+ '+' : 'ADD',
+ '-' : 'DROP',
+ '*' : 'ALTER',
+ }
+
+ def __init__(self, change, schema, table, column, columntype, columndefault, changes=None):
DiffBase.__init__(self)
self.level = 2
self.type = 'column'
@@ -115,8 +133,23 @@
self.schema = schema
self.table = table
self.column = column
+ self.columntype = columntype
+ self.columndefault = columndefault
self.name = column
self.changes = changes
+
+ def format_patch(self):
+ out = 'ALTER TABLE %s.%s %s COLUMN %s %s' % (
+ self.schema,
+ self.table,
+ self.ALTER_COMMANDS[self.change],
+ self.name,
+ self.columntype
+ )
+ if self.columndefault:
+ out += ' DEFAULT ' + self.columndefault
+ out += ';'
+ return out
class DiffConstraint(DiffBase):
@@ -142,22 +175,21 @@
self.include_tables = set()
self.exclude_tables = set()
- def _test_filter(self, schema):
+ def _test_schema(self, schema):
if self.include_schemas and schema not in self.include_schemas:
return False
if schema in self.exclude_schemas:
return False
return True
- def _test_table(self, schema, table):
- name = schema + '.' + table
- if self.include_tables and name not in self.include_tables:
+ def _test_table(self, table):
+ if self.include_tables and table not in self.include_tables:
return False
- if name in self.exclude_tables:
+ if table in self.exclude_tables:
return False
return True
- def _diffnames(self, src, dst):
+ def _diff_names(self, src, dst):
for x in src:
if x in dst:
yield ('*', x)
@@ -186,8 +218,9 @@
return diff
def _diff_columns(self, schema, table, src_columns, dst_columns):
- for nd in self._diffnames(src_columns, dst_columns):
- cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1])
+ for nd in self._diff_names(src_columns, dst_columns):
+ cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1],
+ columntype=dst_columns[nd[1]].type, columndefault=dst_columns[nd[1]].default)
if nd[0] == '*':
a = src_columns[nd[1]]
b = dst_columns[nd[1]]
@@ -198,7 +231,7 @@
yield cdo
def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
- for nd in self._diffnames(src_constraints, dst_constraints):
+ for nd in self._diff_names(src_constraints, dst_constraints):
cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1])
if nd[0] == '*':
a = src_constraints[nd[1]]
@@ -210,8 +243,8 @@
yield cdo
def _difftables(self, schema, src_tables, dst_tables):
- for nd in self._diffnames(src_tables, dst_tables):
- if not self._test_table(schema, nd[1]):
+ for nd in self._diff_names(src_tables, dst_tables):
+ if not self._test_table(nd[1]):
continue
tdo = DiffTable(change=nd[0], schema=schema, table=nd[1])
if nd[0] == '*':
@@ -236,7 +269,7 @@
def iter_diff(self):
'''Return diff between src and dst database schema.
-
+
Yields one line at the time. Each line is in form of object
iherited from DiffBase. This object contains all information
about changes. See format() method.
@@ -244,9 +277,9 @@
'''
src_schemas = self.src.schemas
dst_schemas = self.dst.schemas
- src = [x.name for x in src_schemas.values() if not x.system and self._test_filter(x.name)]
- dst = [x.name for x in dst_schemas.values() if not x.system and self._test_filter(x.name)]
- for nd in self._diffnames(src, dst):
+ src = [x.name for x in src_schemas.values() if not x.system and self._test_schema(x.name)]
+ dst = [x.name for x in dst_schemas.values() if not x.system and self._test_schema(x.name)]
+ for nd in self._diff_names(src, dst):
sdo = DiffSchema(change=nd[0], schema=nd[1])
if nd[0] == '*':
src_tables = src_schemas[nd[1]].tables
@@ -286,7 +319,9 @@
'''
for ln in self.iter_diff():
- print(ln.format_patch())
+ patch = ln.format_patch()
+ if patch:
+ print(patch)
def filter_schemas(self, include=[], exclude=[]):
'''Modify list of schemas which are used for computing diff.