PgBrowser: Add functions. PgDiff: Compare functions.
--- a/pgtoolkit/pgbrowser.py Mon Dec 17 21:12:04 2012 +0100
+++ b/pgtoolkit/pgbrowser.py Thu Jan 24 17:11:17 2013 +0100
@@ -112,9 +112,30 @@
return self._indexes
indexes = property(getindexes)
+
+class Function:
+ def __init__(self, browser, schema, oid, name, type, arguments, result, source):
+ self.browser = browser
+ self.oid = oid
+ self.name = name
+ self.type = type
+ self.arguments = arguments
+ self.result = result
+ self.source = source
+ self._definition = None
+
+ @property
+ def definition(self):
+ """Get full function definition including CREATE command."""
+ if not self._definition:
+ self._definition = self.browser.get_function_definition(self.oid)
+ return self._definition
+
+
class Schema:
def __init__(self, browser, name, owner, acl, description, system):
self._tables = None
+ self._functions = None
self.browser = browser
self.name = name
self.owner = owner
@@ -123,14 +144,28 @@
self.system = system
def refresh(self):
+ self.refresh_tables()
+ self.refresh_functions()
+
+ def refresh_tables(self):
rows = self.browser.list_tables(self.name)
self._tables = OrderedDict([(x['name'], Table(self.browser, self, **x)) for x in rows])
- def gettables(self):
+ def refresh_functions(self):
+ rows = self.browser.list_functions(self.name)
+ self._functions = OrderedDict([(x['name'], Function(self.browser, self, **x)) for x in rows])
+
+ @property
+ def tables(self):
if self._tables is None:
- self.refresh()
+ self.refresh_tables()
return self._tables
- tables = property(gettables)
+
+ @property
+ def functions(self):
+ if self._functions is None:
+ self.refresh_functions()
+ return self._functions
class PgBrowser:
@@ -142,14 +177,17 @@
self.conn = conn
def refresh(self):
+ self.refresh_schemas()
+
+ def refresh_schemas(self):
rows = self.list_schemas()
self._schemas = OrderedDict([(x['name'], Schema(self, **x)) for x in rows])
- def getschemas(self):
+ @property
+ def schemas(self):
if self._schemas is None:
- self.refresh()
+ self.refresh_schemas()
return self._schemas
- schemas = property(getschemas)
def _query(self, query, args):
try:
@@ -263,6 +301,37 @@
ORDER BY i.indisprimary DESC, i.indisunique DESC, c2.relname
''', {'schema': schema, 'table': table})
+ def list_functions(self, schema='public'):
+ '''List functions in schema.'''
+ return self._query('''
+ SELECT
+ p.oid as "oid",
+ p.proname as "name",
+ pg_catalog.pg_get_function_result(p.oid) as "result",
+ pg_catalog.pg_get_function_arguments(p.oid) as "arguments",
+ p.prosrc as "source",
+ CASE
+ WHEN p.proisagg THEN 'agg'
+ WHEN p.proiswindow THEN 'window'
+ WHEN p.prorettype = 'pg_catalog.trigger'::pg_catalog.regtype THEN 'trigger'
+ ELSE 'normal'
+ END as "type"
+ FROM pg_catalog.pg_proc p
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
+ WHERE pg_catalog.pg_function_is_visible(p.oid)
+ AND n.nspname = %s
+ ORDER BY 1, 2, 4;
+ ''', [schema])
+
+ def get_function_definition(self, oid):
+ """Get full function definition, including CREATE command etc.
+
+ Args:
+ oid: function oid from pg_catalog.pg_proc (returned by list_functions)
+
+ """
+ return self._query('''select pg_get_functiondef(%s);''', [oid])
+
def list_sequences(self, schema=None):
'''List sequences in schema.'''
return self._query('''
--- a/pgtoolkit/pgdiff.py Mon Dec 17 21:12:04 2012 +0100
+++ b/pgtoolkit/pgdiff.py Thu Jan 24 17:11:17 2013 +0100
@@ -118,6 +118,30 @@
self.name = table
+class DiffFunction(DiffBase):
+ def __init__(self, change, schema, function):
+ DiffBase.__init__(self)
+ self.level = 1
+ self.type = 'function'
+ self.change = change
+ self.schema = schema
+ self.function = function
+ self.name = function
+
+ def _formatchanges(self):
+ res = []
+ for x in self.changes:
+ type, a, b = x
+ if type == 'source':
+ s = 'Changed source.'
+ else:
+ s = ''.join(['Changed ', type, ' from ',
+ highlight(1,15), a, highlight(0), ' to ',
+ highlight(1,15), b, highlight(0), '.'])
+ res.append(s)
+ return ' '.join(res)
+
+
class DiffColumn(DiffBase):
ALTER_COMMANDS = {
'+' : 'ADD',
@@ -223,6 +247,16 @@
if x not in src:
yield ('+', x)
+ def _compare_functions(self, a, b):
+ diff = []
+ if a.arguments != b.arguments:
+ diff.append(('args', a.arguments, b.arguments))
+ if a.result != b.result:
+ diff.append(('result', a.result, b.result))
+ if a.source != b.source:
+ diff.append(('source', a.source, b.source))
+ return diff
+
def _compare_columns(self, a, b):
diff = []
if a.type != b.type:
@@ -277,7 +311,7 @@
else:
yield cdo
- def _difftables(self, schema, src_tables, dst_tables):
+ def _diff_tables(self, schema, src_tables, dst_tables):
for nd in self._diff_names(src_tables, dst_tables):
if not self._test_table(nd[1]):
continue
@@ -302,6 +336,19 @@
else:
yield tdo
+ def _diff_functions(self, schema, src_functions, dst_functions):
+ for nd in self._diff_names(src_functions, dst_functions):
+ fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1])
+ if nd[0] == '*':
+ # compare function body and arguments
+ a = src_functions[nd[1]]
+ b = dst_functions[nd[1]]
+ fdo.changes = self._compare_functions(a, b)
+ if fdo.changes:
+ yield fdo
+ else:
+ yield fdo
+
def iter_diff(self):
'''Return diff between src and dst database schema.
@@ -317,13 +364,22 @@
for nd in self._diff_names(src, dst):
sdo = DiffSchema(change=nd[0], schema=nd[1])
if nd[0] == '*':
+ # tables
src_tables = src_schemas[nd[1]].tables
dst_tables = dst_schemas[nd[1]].tables
- for tdo in self._difftables(nd[1], src_tables, dst_tables):
+ for tdo in self._diff_tables(nd[1], src_tables, dst_tables):
if sdo:
yield sdo
sdo = None
yield tdo
+ # functions
+ src_functions = src_schemas[nd[1]].functions
+ dst_functions = dst_schemas[nd[1]].functions
+ for fdo in self._diff_functions(nd[1], src_functions, dst_functions):
+ if sdo:
+ yield sdo
+ sdo = None
+ yield fdo
else:
yield sdo