# PgDiff: Add patch support for SQL functions.
# -*- coding: utf-8 -*-
#
# PgDiff - capture differences of database metadata
#
# Depends on PgBrowser
#
# Copyright (c) 2011 Radek Brich <radek.brich@devl.cz>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from pycolib.ansicolor import *
import re
import difflib
class PgDiffError(Exception):
    '''Error raised by PgDiff, e.g. when a requested schema does not exist in the source database.'''
class DiffBase:
    '''Common base for one item of diff output.

    Subclasses set ``level`` (indentation depth of the printed line),
    ``type`` (object kind, e.g. 'schema' or 'column'), ``change``
    ('+' added, '-' removed, '*' changed) and ``name``. ``changes`` is
    None or a list of ``(attribute, old, new)`` tuples.
    '''

    # highlight color of the change marker
    COLORS = {
        '+' : BOLD | GREEN,
        '-' : BOLD | RED,
        '*' : BOLD | YELLOW,
    }

    # SQL command used by the generic format_patch
    COMMANDS = {
        '+' : 'CREATE',
        '-' : 'DROP',
        '*' : 'ALTER',
    }

    def __init__(self):
        self.changes = None

    def format(self):
        '''Return the item as one highlighted, human readable line.'''
        out = [' ' * self.level]
        out.append(highlight(1, self.COLORS[self.change]))
        out.append(self.change)
        out += [' ', self.type, ' ', self.name, highlight(0)]
        if self.changes:
            out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)]
        return ''.join(out)

    def _formatnotnull(self, notnull):
        '''Return 'NOT NULL' for a true flag, None otherwise.'''
        if notnull:
            return 'NOT NULL'
        else:
            return None

    def _formatchanges(self):
        '''Return one sentence per entry of ``self.changes``, space separated.'''
        res = []
        for attr, a, b in self.changes:
            if attr == 'notnull':
                # the value itself reads as the description ("Added NOT NULL."),
                # so no attribute label is printed
                attr = ''
                a = self._formatnotnull(a)
                b = self._formatnotnull(b)
            if a and b:
                s = ''.join(['Changed ', attr, ' from ',
                    highlight(1,15), a, highlight(0), ' to ',
                    highlight(1,15), b, highlight(0), '.'])
            elif a:
                l = ['Removed ']
                if attr:
                    l += [attr, ' ']
                l += [highlight(1,15), a, highlight(0), '.']
                s = ''.join(l)
            elif b:
                l = ['Added ']
                if attr:
                    l += [attr, ' ']
                l += [highlight(1,15), b, highlight(0), '.']
                s = ''.join(l)
            else:
                # both values are falsy yet differ (e.g. '' vs None); without
                # this branch ``s`` would be unbound or left over from the
                # previous change
                s = ''.join(['Changed ', attr, ' from ',
                    highlight(1,15), repr(a), highlight(0), ' to ',
                    highlight(1,15), repr(b), highlight(0), '.'])
            res.append(s)
        return ' '.join(res)

    def format_patch(self):
        '''Return a list of SQL statements applying this change, or None.

        A changed schema or table has no statement of its own; its nested
        differences (columns, constraints, ...) are separate diff items.
        '''
        if self.change == '*' and self.type in ('schema', 'table'):
            return None
        return ['%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)]
class DiffSchema(DiffBase):
    '''Diff item for a whole schema, printed at the top level (no indentation).'''

    def __init__(self, change, schema):
        DiffBase.__init__(self)
        self.type = 'schema'
        self.level = 0
        self.change = change
        self.schema = self.name = schema
class DiffTable(DiffBase):
    '''Diff item for a table, nested one level under its schema.'''

    def __init__(self, change, schema, table):
        DiffBase.__init__(self)
        self.type = 'table'
        self.level = 1
        self.change = change
        self.schema = schema
        self.table = self.name = table
class DiffArgument(DiffBase):
    '''Diff item for a function argument, nested under its function.'''

    def __init__(self, change, schema, function, argument):
        DiffBase.__init__(self)
        self.level = 2
        self.type = 'argument'
        self.change = change
        self.schema = schema
        self.function = function
        self.argument = argument
        self.name = argument

    def format_patch(self):
        '''Return None: argument changes produce no SQL of their own.

        The inherited implementation would emit e.g. "CREATE ARGUMENT x;",
        which is not valid SQL. Arguments belong to the function
        definition, which the enclosing DiffFunction item emits.
        '''
        # NOTE(review): if the argument types change, the old function may
        # need an explicit DROP as well; that is not generated here.
        return None
class DiffFunction(DiffBase):
    '''Diff item for a function.

    ``definition`` is the function definition taken from the dst database
    (None when the function does not exist there); the patch emits it to
    (re)create the function.
    '''

    def __init__(self, change, schema, function, definition, show_body_diff=False):
        DiffBase.__init__(self)
        self.level = 1
        self.type = 'function'
        self.change = change
        self.schema = schema
        self.function = function
        #: New function definition
        self.definition = definition
        self.name = function
        self.show_body_diff = show_body_diff

    def _formatchanges(self):
        '''Describe changed attributes; a 'source' change is shown as a diff or a note.'''
        res = []
        for attr, a, b in self.changes:
            if attr == 'source':
                if self.show_body_diff:
                    res.append(self._formatsourcediff(a, b))
                else:
                    res.append('Source differs.')
            else:
                res.append(''.join(['Changed ', attr, ' from ',
                    highlight(1,15), a, highlight(0), ' to ',
                    highlight(1,15), b, highlight(0), '.']))
        return ' '.join(res)

    def _formatsourcediff(self, a, b):
        '''Return a colored unified diff of the source line lists ``a`` and ``b``.'''
        colors = {' ': WHITE, '-': YELLOW, '+': GREEN, '@': WHITE|BOLD}
        lines = ['Source differs:\n']
        for i, line in enumerate(difflib.unified_diff(a, b, lineterm='')):
            # unified_diff starts with the two '---' / '+++' file header
            # lines; skip them by position. Filtering on a '---' prefix would
            # also drop removed lines that start with '--' (SQL comments).
            if i < 2:
                continue
            lines.append(highlight(1, colors[line[0]]) + line + highlight(0) + '\n')
        return ''.join(lines)

    def format_patch(self):
        '''Return SQL that recreates the function from dst, or drops it.

        A removed function has no dst definition; a DROP FUNCTION statement
        is generated instead of returning [None], which made print_patch
        fail when joining the statements.
        '''
        if self.definition is None:
            if self.change == '-':
                # NOTE(review): assumes the function key is usable in DROP
                # FUNCTION; a bare, overloaded name would be ambiguous.
                return ['DROP FUNCTION %s.%s;' % (self.schema, self.function)]
            return None
        definition = self.definition.rstrip()
        # terminate the statement so it does not run into the next patch
        # statement; keep an existing terminator as is
        if not definition.endswith(';'):
            definition += ';'
        return [definition]
class DiffColumn(DiffBase):
    '''Diff item for a table column; its patch is a list of ALTER TABLE statements.'''

    # ALTER TABLE sub-command for each change marker
    ALTER_COMMANDS = {
        '+' : 'ADD',
        '-' : 'DROP',
        '*' : 'ALTER',
    }

    def __init__(self, change, schema, table, column, columntype, columndefault, columnnotnull, changes=None):
        DiffBase.__init__(self)
        self.type = 'column'
        self.level = 2
        self.change = change
        self.schema = schema
        self.table = table
        self.column = self.name = column
        # type/default/not-null of the column on the dst side (None if dropped)
        self.columntype = columntype
        self.columndefault = columndefault
        self.columnnotnull = columnnotnull
        self.changes = changes

    def format_patch(self):
        '''Return ALTER TABLE statements adding, dropping or altering the column.'''
        prefix = 'ALTER TABLE %s.%s %s COLUMN %s' % (
            self.schema, self.table, self.ALTER_COMMANDS[self.change], self.name)
        if self.change == '-':
            return ['%s;' % prefix]
        if self.change == '+':
            suffix = ' NOT NULL' if self.columnnotnull else ''
            if self.columndefault:
                suffix += ' DEFAULT %s' % self.columndefault
            return ['%s %s%s;' % (prefix, self.columntype, suffix)]
        statements = []
        if self.change == '*':
            for attr, old, new in self.changes:
                if attr == 'type':
                    statements.append('%s TYPE %s;' % (prefix, new))
                elif attr == 'notnull':
                    if old and not new:
                        statements.append('%s DROP NOT NULL;' % prefix)
                    elif new and not old:
                        statements.append('%s SET NOT NULL;' % prefix)
                elif attr == 'default':
                    if new:
                        statements.append('%s SET DEFAULT %s;' % (prefix, new))
                    else:
                        statements.append('%s DROP DEFAULT;' % prefix)
        return statements
class DiffConstraint(DiffBase):
    '''Diff item for a table constraint; patched through ALTER TABLE.'''

    def __init__(self, change, schema, table, constraint, definition, changes=None):
        DiffBase.__init__(self)
        self.type = 'constraint'
        self.level = 2
        self.change = change
        self.schema = schema
        self.table = table
        self.constraint = self.name = constraint
        # constraint definition on the dst side (None if dropped)
        self.definition = definition
        self.changes = changes

    def format_patch(self):
        '''Return statements dropping and/or adding the constraint.

        A changed constraint is dropped and then re-added with the dst
        definition.
        '''
        prefix = 'ALTER TABLE %s.%s' % (self.schema, self.table)
        drop_stmt = '%s DROP CONSTRAINT %s;' % (prefix, self.constraint)
        add_stmt = '%s ADD CONSTRAINT %s %s;' % (prefix, self.constraint, self.definition)
        by_change = {
            '*': [drop_stmt, add_stmt],
            '+': [add_stmt],
            '-': [drop_stmt],
        }
        return by_change[self.change]
class DiffIndex(DiffBase):
    '''Diff item for an index; the patch drops and/or recreates it.'''

    def __init__(self, change, schema, table, index, definition, changes=None):
        DiffBase.__init__(self)
        self.level = 2
        self.type = 'index'
        self.change = change
        self.schema = schema
        self.table = table
        self.index = index
        self.name = index
        # index definition on the dst side (None if dropped)
        self.definition = definition
        self.changes = changes

    def format_patch(self):
        '''Return SQL statements applying this index change.

        The dst definition is emitted as-is (terminated with ';') to create
        the index; a changed index is dropped and recreated.
        '''
        # an index always lives in its table's schema; qualify the name so
        # the DROP works when that schema is not on search_path (matching
        # the schema-qualified ALTER TABLE used for constraints)
        q_drop = 'DROP INDEX %s.%s;' % (self.schema, self.index)
        q_add = '%s;' % (self.definition,)
        if self.change == '*':
            out = [q_drop, q_add]
        if self.change == '+':
            out = [q_add]
        if self.change == '-':
            out = [q_drop]
        return out
class DiffType(DiffBase):
    '''Diff item for a user-defined type, nested one level under its schema.'''

    def __init__(self, change, schema, name):
        DiffBase.__init__(self)
        self.level = 1
        self.type = 'type'
        self.change = change
        self.schema = schema
        self.name = name

    def format_patch(self):
        '''Return SQL for an added/dropped type, or None for a changed one.

        The inherited implementation emits "ALTER TYPE name;" for a changed
        type, which is not a valid statement. The recorded changes (type
        kind, elements) do not map to one ALTER, so nothing is emitted and
        the change is left for manual handling.
        '''
        if self.change == '*':
            return None
        # NOTE(review): for '+' this is a bare "CREATE TYPE", which in
        # PostgreSQL only creates a shell type; the real definition still
        # has to be supplied by hand.
        return ['%s TYPE %s.%s;' % (self.COMMANDS[self.change], self.schema, self.name)]
class PgDiff:
    '''Compute differences between the metadata of two databases.

    ``srcbrowser`` and ``dstbrowser`` are browser objects (see PgBrowser)
    whose ``schemas`` mappings are compared. Differences are reported as
    the changes needed to update the src schema to the dst schema.
    '''

    def __init__(self, srcbrowser=None, dstbrowser=None):
        self.allowcolor = False
        self.src = srcbrowser
        self.dst = dstbrowser
        self.include_schemas = set()  # if not empty, consider only these schemas for diff
        self.exclude_schemas = set()  # exclude these schemas from diff
        self.include_tables = set()   # if not empty, consider only these tables for diff
        self.exclude_tables = set()   # exclude these tables from diff
        # functions are diffed only if their name matches (re.match anchors
        # at the start of the name); the empty pattern matches every name
        self.function_regex = re.compile(r"")
        # show a unified diff of function sources instead of "Source differs."
        self.function_body_diff = False

    def _test_schema(self, schema):
        '''Return True if ``schema`` passes the include/exclude filters.'''
        if self.include_schemas and schema not in self.include_schemas:
            return False
        if schema in self.exclude_schemas:
            return False
        return True

    def _test_table(self, table):
        '''Return True if ``table`` passes the include/exclude filters.'''
        if self.include_tables and table not in self.include_tables:
            return False
        if table in self.exclude_tables:
            return False
        return True

    def _test_function(self, function):
        '''Return True if ``function`` matches the function filter regex.'''
        return bool(self.function_regex.match(function))

    def _diff_names(self, src, dst):
        '''Pair up the names in ``src`` and ``dst``.

        Yields ``('*', name)`` for names present in both (candidates for a
        change), ``('-', name)`` for names only in src and ``('+', name)``
        for names only in dst. Src order comes first, then new dst names.
        '''
        # set membership keeps this linear even when the inputs are plain
        # lists (iter_diff passes lists of schema names)
        src_names = set(src)
        dst_names = set(dst)
        for name in src:
            if name in dst_names:
                yield ('*', name)
            else:
                yield ('-', name)
        for name in dst:
            if name not in src_names:
                yield ('+', name)

    def _compare_columns(self, a, b):
        '''Return (attribute, src, dst) tuples for differing column attributes.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.notnull != b.notnull:
            diff.append(('notnull', a.notnull, b.notnull))
        if a.default != b.default:
            diff.append(('default', a.default, b.default))
        return diff

    def _compare_constraints(self, a, b):
        '''Return (attribute, src, dst) tuples for differing constraint attributes.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.definition != b.definition:
            diff.append(('definition', a.definition, b.definition))
        return diff

    def _compare_indexes(self, a, b):
        '''Return (attribute, src, dst) tuples for differing index attributes.'''
        diff = []
        if a.definition != b.definition:
            diff.append(('definition', a.definition, b.definition))
        return diff

    def _compare_functions(self, a, b):
        '''Return (attribute, src, dst) tuples for a differing result type or source.

        Source is reported as lists of lines.
        '''
        diff = []
        if a.result != b.result:
            diff.append(('result', a.result, b.result))
        # function source may differ in newlines (\n vs \r\n);
        # split lines before comparison, so that these differences are ignored
        a_source = a.source.splitlines()
        b_source = b.source.splitlines()
        if a_source != b_source:
            diff.append(('source', a_source, b_source))
        return diff

    def _compare_arguments(self, a, b):
        '''Return (attribute, src, dst) tuples for differing argument attributes.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.mode != b.mode:
            diff.append(('mode', a.mode, b.mode))
        if a.default != b.default:
            diff.append(('default', a.default, b.default))
        return diff

    def _compare_types(self, a, b):
        '''Return (attribute, src, dst) tuples for a differing type kind or elements.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.elements != b.elements:
            diff.append(('elements', repr(a.elements), repr(b.elements)))
        return diff

    def _diff_columns(self, schema, table, src_columns, dst_columns):
        '''Yield a DiffColumn for each added, dropped or changed column.'''
        for change, column in self._diff_names(src_columns, dst_columns):
            if column in dst_columns:
                dst_type = dst_columns[column].type
                dst_default = dst_columns[column].default
                dst_notnull = dst_columns[column].notnull
            else:
                dst_type = None
                dst_default = None
                dst_notnull = None
            cdo = DiffColumn(change=change, schema=schema, table=table, column=column,
                columntype=dst_type, columndefault=dst_default, columnnotnull=dst_notnull)
            if change == '*':
                # present on both sides: report only if something differs
                cdo.changes = self._compare_columns(src_columns[column], dst_columns[column])
                if cdo.changes:
                    yield cdo
            else:
                yield cdo

    def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
        '''Yield a DiffConstraint for each added, dropped or changed constraint.'''
        for change, constraint in self._diff_names(src_constraints, dst_constraints):
            if constraint in dst_constraints:
                dst_definition = dst_constraints[constraint].definition
            else:
                dst_definition = None
            cdo = DiffConstraint(change=change, schema=schema, table=table,
                constraint=constraint, definition=dst_definition)
            if change == '*':
                cdo.changes = self._compare_constraints(
                    src_constraints[constraint], dst_constraints[constraint])
                if cdo.changes:
                    yield cdo
            else:
                yield cdo

    def _diff_indexes(self, schema, table, src_indexes, dst_indexes):
        '''Yield a DiffIndex for each added, dropped or changed index.'''
        for change, index in self._diff_names(src_indexes, dst_indexes):
            if index in dst_indexes:
                dst_definition = dst_indexes[index].definition
            else:
                dst_definition = None
            ido = DiffIndex(change=change, schema=schema, table=table, index=index,
                definition=dst_definition)
            if change == '*':
                ido.changes = self._compare_indexes(src_indexes[index], dst_indexes[index])
                if ido.changes:
                    yield ido
            else:
                yield ido

    def _diff_tables(self, schema, src_tables, dst_tables):
        '''Yield table diffs, each followed by its column/constraint/index diffs.

        A table present on both sides is yielded once, right before its
        first nested difference, and not at all if nothing differs.
        '''
        for change, table in self._diff_names(src_tables, dst_tables):
            if not self._test_table(table):
                continue
            tdo = DiffTable(change=change, schema=schema, table=table)
            if change == '*':
                # columns
                src_columns = src_tables[table].columns
                dst_columns = dst_tables[table].columns
                for cdo in self._diff_columns(schema, table, src_columns, dst_columns):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield cdo
                # constraints
                src_constraints = src_tables[table].constraints
                dst_constraints = dst_tables[table].constraints
                for cdo in self._diff_constraints(schema, table, src_constraints, dst_constraints):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield cdo
                # indexes
                src_indexes = src_tables[table].indexes
                dst_indexes = dst_tables[table].indexes
                for ido in self._diff_indexes(schema, table, src_indexes, dst_indexes):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield ido
            else:
                yield tdo

    def _diff_arguments(self, schema, function, src_args, dst_args):
        '''Yield a DiffArgument for each added, dropped or changed argument.'''
        for change, argument in self._diff_names(src_args, dst_args):
            ado = DiffArgument(change=change, schema=schema, function=function, argument=argument)
            if change == '*':
                ado.changes = self._compare_arguments(src_args[argument], dst_args[argument])
                if ado.changes:
                    yield ado
            else:
                yield ado

    def _diff_functions(self, schema, src_functions, dst_functions):
        '''Yield function diffs, each followed by its argument diffs.

        Only functions matching ``function_regex`` are considered. A function
        present on both sides is yielded if its result or source differs,
        otherwise only right before its first argument difference.
        '''
        for change, function in self._diff_names(src_functions, dst_functions):
            if not self._test_function(function):
                continue
            if function in dst_functions:
                dst_definition = dst_functions[function].definition
            else:
                dst_definition = None
            fdo = DiffFunction(change=change, schema=schema, function=function,
                definition=dst_definition,
                show_body_diff=self.function_body_diff)
            if change == '*':
                # compare function body and result
                a = src_functions[function]
                b = dst_functions[function]
                fdo.changes = self._compare_functions(a, b)
                if fdo.changes:
                    yield fdo
                    fdo = None
                # arguments
                src_args = src_functions[function].arguments
                dst_args = dst_functions[function].arguments
                for ado in self._diff_arguments(schema, function, src_args, dst_args):
                    if fdo:
                        yield fdo
                        fdo = None
                    yield ado
            else:
                yield fdo

    def _diff_types(self, schema, src_types, dst_types):
        '''Yield a DiffType for each added, dropped or changed type.'''
        for change, name in self._diff_names(src_types, dst_types):
            tdo = DiffType(change=change, schema=schema, name=name)
            if change == '*':
                tdo.changes = self._compare_types(src_types[name], dst_types[name])
                if tdo.changes:
                    yield tdo
            else:
                yield tdo

    def iter_diff(self):
        '''Return diff between src and dst database schema.

        Yields one line at a time. Each line is an object inherited
        from DiffBase. This object contains all information about
        changes. See format() method.
        '''
        src_schemas = self.src.schemas
        dst_schemas = self.dst.schemas
        src = [x.name for x in src_schemas.values() if not x.system and self._test_schema(x.name)]
        dst = [x.name for x in dst_schemas.values() if not x.system and self._test_schema(x.name)]
        for change, schema in self._diff_names(src, dst):
            sdo = DiffSchema(change=change, schema=schema)
            if change == '*':
                # the schema line is emitted lazily, before its first nested diff
                # tables
                src_tables = src_schemas[schema].tables
                dst_tables = dst_schemas[schema].tables
                for tdo in self._diff_tables(schema, src_tables, dst_tables):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield tdo
                # functions
                src_functions = src_schemas[schema].functions
                dst_functions = dst_schemas[schema].functions
                for fdo in self._diff_functions(schema, src_functions, dst_functions):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield fdo
                # types
                src_types = src_schemas[schema].types
                dst_types = dst_schemas[schema].types
                for tdo in self._diff_types(schema, src_types, dst_types):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield tdo
            else:
                yield sdo

    def print_diff(self):
        '''Print diff between src and dst database schema.

        The output is in human readable form.
        Set allowcolor=True of PgDiff instance to get colored output.
        '''
        for ln in self.iter_diff():
            print(ln.format())

    def print_patch(self):
        '''Print patch for updating from src schema to dst schema.

        Supports schema and table create/drop, column add/drop and these
        column changes:
            - type
            - set/remove not null
            - default value
        Constraints and indexes are added, dropped, or dropped and
        re-created. Functions are emitted using their dst definition.

        This is experimental, not tested very much.
        Do not use without checking the commands.
        Even if it works as intended, it can cause table lock ups
        and/or loss of data. You have been warned.
        '''
        for ln in self.iter_diff():
            patch = ln.format_patch()
            if patch:
                print('\n'.join(patch))

    def filter_schemas(self, include=(), exclude=()):
        '''Modify list of schemas which are used for computing diff.

        include (iterable) -- if not empty, consider only these schemas for diff
        exclude (iterable) -- exclude these schemas from diff
        Order: include, exclude
        include=() means include everything

        Raises:
            PgDiffError: when schema from include list is not found in src db
        '''
        for schema in include:
            self._check_schema_exist(schema)
        self.include_schemas.clear()
        self.include_schemas.update(include)
        self.exclude_schemas.clear()
        self.exclude_schemas.update(exclude)

    def filter_tables(self, include=(), exclude=()):
        '''Set table name filters (same semantics as filter_schemas, no existence check).'''
        self.include_tables.clear()
        self.include_tables.update(include)
        self.exclude_tables.clear()
        self.exclude_tables.update(exclude)

    def filter_functions(self, regex=''):
        '''Diff only functions whose name matches ``regex`` from its start.'''
        self.function_regex = re.compile(regex)

    def _check_schema_exist(self, schema):
        '''Raise PgDiffError if ``schema`` is missing from the src database.'''
        if schema not in self.src.schemas:
            raise PgDiffError('Schema "%s" not found in source database.' % schema)