PgBrowser: Add functions. PgDiff: Compare functions.
authorRadek Brich <radek.brich@devl.cz>
Thu, 24 Jan 2013 17:11:17 +0100
changeset 58 0bcc13460dae
parent 57 ba323bbed6a4
child 59 65efd0c6919f
PgBrowser: Add functions. PgDiff: Compare functions.
pgtoolkit/pgbrowser.py
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgbrowser.py	Mon Dec 17 21:12:04 2012 +0100
+++ b/pgtoolkit/pgbrowser.py	Thu Jan 24 17:11:17 2013 +0100
@@ -112,9 +112,30 @@
         return self._indexes
     indexes = property(getindexes)
 
+
+class Function:
+    def __init__(self, browser, schema, oid, name, type, arguments, result, source):
+        self.browser = browser
+        self.oid = oid
+        self.name = name
+        self.type = type
+        self.arguments = arguments
+        self.result = result
+        self.source = source
+        self._definition = None
+
+    @property
+    def definition(self):
+        """Get full function definition including CREATE command."""
+        if not self._definition:
+            self._definition = self.browser.get_function_definition(self.oid)
+        return self._definition
+
+
 class Schema:
     def __init__(self, browser, name, owner, acl, description, system):
         self._tables = None
+        self._functions = None
         self.browser = browser
         self.name = name
         self.owner = owner
@@ -123,14 +144,28 @@
         self.system = system
 
     def refresh(self):
+        self.refresh_tables()
+        self.refresh_functions()
+
+    def refresh_tables(self):
         rows = self.browser.list_tables(self.name)
         self._tables = OrderedDict([(x['name'], Table(self.browser, self, **x)) for x in rows])
 
-    def gettables(self):
+    def refresh_functions(self):
+        rows = self.browser.list_functions(self.name)
+        self._functions = OrderedDict([(x['name'], Function(self.browser, self, **x)) for x in rows])
+
+    @property
+    def tables(self):
         if self._tables is None:
-            self.refresh()
+            self.refresh_tables()
         return self._tables
-    tables = property(gettables)
+
+    @property
+    def functions(self):
+        if self._functions is None:
+            self.refresh_functions()
+        return self._functions
 
 
 class PgBrowser:
@@ -142,14 +177,17 @@
         self.conn = conn
 
     def refresh(self):
+        self.refresh_schemas()
+
+    def refresh_schemas(self):
         rows = self.list_schemas()
         self._schemas = OrderedDict([(x['name'], Schema(self, **x)) for x in rows])
 
-    def getschemas(self):
+    @property
+    def schemas(self):
         if self._schemas is None:
-            self.refresh()
+            self.refresh_schemas()
         return self._schemas
-    schemas = property(getschemas)
 
     def _query(self, query, args):
         try:
@@ -263,6 +301,37 @@
             ORDER BY i.indisprimary DESC, i.indisunique DESC, c2.relname
             ''', {'schema': schema, 'table': table})
 
+    def list_functions(self, schema='public'):
+        '''List functions in schema.'''
+        return self._query('''
+            SELECT
+                p.oid as "oid",
+                p.proname as "name",
+                pg_catalog.pg_get_function_result(p.oid) as "result",
+                pg_catalog.pg_get_function_arguments(p.oid) as "arguments",
+                p.prosrc as "source",
+                CASE
+                    WHEN p.proisagg THEN 'agg'
+                    WHEN p.proiswindow THEN 'window'
+                    WHEN p.prorettype = 'pg_catalog.trigger'::pg_catalog.regtype THEN 'trigger'
+                    ELSE 'normal'
+                END as "type"
+            FROM pg_catalog.pg_proc p
+            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
+            WHERE pg_catalog.pg_function_is_visible(p.oid)
+                AND n.nspname = %s
+            ORDER BY 1, 2, 4;
+            ''', [schema])
+
+    def get_function_definition(self, oid):
+        """Get full function definition, including CREATE command etc.
+
+        Args:
+            oid: function oid from pg_catalog.pg_proc (returned by list_functions)
+
+        """
+        return self._query('''select pg_get_functiondef(%s);''', [oid])
+
     def list_sequences(self, schema=None):
         '''List sequences in schema.'''
         return self._query('''
--- a/pgtoolkit/pgdiff.py	Mon Dec 17 21:12:04 2012 +0100
+++ b/pgtoolkit/pgdiff.py	Thu Jan 24 17:11:17 2013 +0100
@@ -118,6 +118,30 @@
         self.name = table
 
 
+class DiffFunction(DiffBase):
+    def __init__(self, change, schema, function):
+        DiffBase.__init__(self)
+        self.level = 1
+        self.type = 'function'
+        self.change = change
+        self.schema = schema
+        self.function = function
+        self.name = function
+
+    def _formatchanges(self):
+        res = []
+        for x in self.changes:
+            type, a, b = x
+            if type == 'source':
+                s = 'Changed source.'
+            else:
+                s = ''.join(['Changed ', type, ' from ',
+                    highlight(1,15), a, highlight(0), ' to ',
+                    highlight(1,15), b, highlight(0), '.'])
+            res.append(s)
+        return ' '.join(res)
+
+
 class DiffColumn(DiffBase):
     ALTER_COMMANDS = {
         '+' : 'ADD',
@@ -223,6 +247,16 @@
             if x not in src:
                 yield ('+', x)
 
+    def _compare_functions(self, a, b):
+        diff = []
+        if a.arguments != b.arguments:
+            diff.append(('args', a.arguments, b.arguments))
+        if a.result != b.result:
+            diff.append(('result', a.result, b.result))
+        if a.source != b.source:
+            diff.append(('source', a.source, b.source))
+        return diff
+
     def _compare_columns(self, a, b):
         diff = []
         if a.type != b.type:
@@ -277,7 +311,7 @@
             else:
                 yield cdo
 
-    def _difftables(self, schema, src_tables, dst_tables):
+    def _diff_tables(self, schema, src_tables, dst_tables):
         for nd in self._diff_names(src_tables, dst_tables):
             if not self._test_table(nd[1]):
                 continue
@@ -302,6 +336,19 @@
             else:
                 yield tdo
 
+    def _diff_functions(self, schema, src_functions, dst_functions):
+        for nd in self._diff_names(src_functions, dst_functions):
+            fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1])
+            if nd[0] == '*':
+                # compare function body and arguments
+                a = src_functions[nd[1]]
+                b = dst_functions[nd[1]]
+                fdo.changes = self._compare_functions(a, b)
+                if fdo.changes:
+                    yield fdo
+            else:
+                yield fdo
+
     def iter_diff(self):
         '''Return diff between src and dst database schema.
 
@@ -317,13 +364,22 @@
         for nd in self._diff_names(src, dst):
             sdo = DiffSchema(change=nd[0], schema=nd[1])
             if nd[0] == '*':
+                # tables
                 src_tables = src_schemas[nd[1]].tables
                 dst_tables = dst_schemas[nd[1]].tables
-                for tdo in self._difftables(nd[1], src_tables, dst_tables):
+                for tdo in self._diff_tables(nd[1], src_tables, dst_tables):
                     if sdo:
                         yield sdo
                         sdo = None
                     yield tdo
+                # functions
+                src_functions = src_schemas[nd[1]].functions
+                dst_functions = dst_schemas[nd[1]].functions
+                for fdo in self._diff_functions(nd[1], src_functions, dst_functions):
+                    if sdo:
+                        yield sdo
+                        sdo = None
+                    yield fdo
             else:
                 yield sdo