diff -r 037410ef2b6b -r bb8c729ae6ce pgtoolkit/pgdiff.py --- a/pgtoolkit/pgdiff.py Wed Sep 26 23:29:54 2012 +0200 +++ b/pgtoolkit/pgdiff.py Wed Sep 26 23:32:02 2012 +0200 @@ -32,7 +32,14 @@ COLORS = { '+' : BOLD | GREEN, '-' : BOLD | RED, - '*' : BOLD | YELLOW} + '*' : BOLD | YELLOW, + } + + COMMANDS = { + '+' : 'CREATE', + '-' : 'DROP', + '*' : 'ALTER', + } def __init__(self): self.changes = None @@ -46,24 +53,24 @@ out += [' ', self.type, ' ', self.name, highlight(0)] if self.changes: - out += [highlight(1, WHITE), ' (', self.formatchanges(), ')', highlight(0)] + out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)] return ''.join(out) - def formatnotnull(self, notnull): + def _formatnotnull(self, notnull): if notnull: return 'NOT NULL' else: return None - def formatchanges(self): + def _formatchanges(self): res = [] for x in self.changes: type, a, b = x if type == 'notnull': type = '' - a = self.formatnotnull(a) - b = self.formatnotnull(b) + a = self._formatnotnull(a) + b = self._formatnotnull(b) if a and b: s = ''.join(['Changed ', type, ' from ', @@ -83,6 +90,11 @@ s = ''.join(l) res.append(s) return ' '.join(res) + + def format_patch(self): + if self.change == '*' and self.type in ('schema', 'table'): + return None + return '%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name) class DiffSchema(DiffBase): @@ -93,7 +105,7 @@ self.change = change self.schema = schema self.name = schema - + class DiffTable(DiffBase): def __init__(self, change, schema, table): @@ -107,7 +119,13 @@ class DiffColumn(DiffBase): - def __init__(self, change, schema, table, column, changes=None): + ALTER_COMMANDS = { + '+' : 'ADD', + '-' : 'DROP', + '*' : 'ALTER', + } + + def __init__(self, change, schema, table, column, columntype, columndefault, changes=None): DiffBase.__init__(self) self.level = 2 self.type = 'column' @@ -115,8 +133,23 @@ self.schema = schema self.table = table self.column = column + self.columntype = columntype + self.columndefault = columndefault self.name = column self.changes = changes + + def format_patch(self): + out = 'ALTER TABLE %s.%s %s COLUMN %s %s' % ( + self.schema, + self.table, + self.ALTER_COMMANDS[self.change], + self.name, + self.columntype + ) + if self.columndefault: + out += ' DEFAULT ' + self.columndefault + out += ';' + return out class DiffConstraint(DiffBase): @@ -142,22 +175,21 @@ self.include_tables = set() self.exclude_tables = set() - def _test_filter(self, schema): + def _test_schema(self, schema): if self.include_schemas and schema not in self.include_schemas: return False if schema in self.exclude_schemas: return False return True - def _test_table(self, schema, table): - name = schema + '.' + table - if self.include_tables and name not in self.include_tables: + def _test_table(self, table): + if self.include_tables and table not in self.include_tables: return False - if name in self.exclude_tables: + if table in self.exclude_tables: return False return True - def _diffnames(self, src, dst): + def _diff_names(self, src, dst): for x in src: if x in dst: yield ('*', x) @@ -186,8 +218,9 @@ return diff def _diff_columns(self, schema, table, src_columns, dst_columns): - for nd in self._diffnames(src_columns, dst_columns): - cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1]) + for nd in self._diff_names(src_columns, dst_columns): + cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1], + columntype=dst_columns[nd[1]].type, columndefault=dst_columns[nd[1]].default) if nd[0] == '*': a = src_columns[nd[1]] b = dst_columns[nd[1]] @@ -198,7 +231,7 @@ yield cdo def _diff_constraints(self, schema, table, src_constraints, dst_constraints): - for nd in self._diffnames(src_constraints, dst_constraints): + for nd in self._diff_names(src_constraints, dst_constraints): cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1]) if nd[0] == '*': a = src_constraints[nd[1]] @@ -210,8 +243,8 @@ yield cdo def _difftables(self, schema, src_tables, dst_tables): - for nd in self._diffnames(src_tables, dst_tables): - if not self._test_table(schema, nd[1]): + for nd in self._diff_names(src_tables, dst_tables): + if not self._test_table(nd[1]): continue tdo = DiffTable(change=nd[0], schema=schema, table=nd[1]) if nd[0] == '*': @@ -236,7 +269,7 @@ def iter_diff(self): '''Return diff between src and dst database schema. - + Yields one line at the time. Each line is in form of object iherited from DiffBase. This object contains all information about changes. See format() method. @@ -244,9 +277,9 @@ ''' src_schemas = self.src.schemas dst_schemas = self.dst.schemas - src = [x.name for x in src_schemas.values() if not x.system and self._test_filter(x.name)] - dst = [x.name for x in dst_schemas.values() if not x.system and self._test_filter(x.name)] - for nd in self._diffnames(src, dst): + src = [x.name for x in src_schemas.values() if not x.system and self._test_schema(x.name)] + dst = [x.name for x in dst_schemas.values() if not x.system and self._test_schema(x.name)] + for nd in self._diff_names(src, dst): sdo = DiffSchema(change=nd[0], schema=nd[1]) if nd[0] == '*': src_tables = src_schemas[nd[1]].tables @@ -286,7 +319,9 @@ ''' for ln in self.iter_diff(): - print(ln.format_patch()) + patch = ln.format_patch() + if patch: + print(patch) def filter_schemas(self, include=[], exclude=[]): '''Modify list of schemas which are used for computing diff.