PgBrowser: add function arguments as another level in hierarchy. PgDiff: compare function arguments one by one.
authorRadek Brich <radek.brich@devl.cz>
Fri, 25 Jan 2013 17:06:54 +0100
changeset 59 65efd0c6919f
parent 58 0bcc13460dae
child 60 bb6b20106ff5
PgBrowser: add function arguments as another level in hierarchy. PgDiff: compare function arguments one by one.
pgtoolkit/pgbrowser.py
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgbrowser.py	Thu Jan 24 17:11:17 2013 +0100
+++ b/pgtoolkit/pgbrowser.py	Fri Jan 25 17:06:54 2013 +0100
@@ -113,17 +113,45 @@
     indexes = property(getindexes)
 
 
-class Function:
-    def __init__(self, browser, schema, oid, name, type, arguments, result, source):
+class Argument:
+    def __init__(self, browser, function, name, type, mode, default):
+        # PgBrowser instance
         self.browser = browser
-        self.oid = oid
+        # Function instance
+        self.function = function
         self.name = name
         self.type = type
-        self.arguments = arguments
+        self.mode = mode
+        self.default = default
+
+class Function:
+    def __init__(self, browser, schema, oid, name, function_name, type, result, source):
+        self.browser = browser
+        self.schema = schema
+        self.oid = oid
+        #: unique name - function name + arg types
+        self.name = name
+        #: pure function name without args
+        self.function_name = function_name
+        self.type = type
         self.result = result
         self.source = source
+        self._arguments = None
         self._definition = None
 
+    def refresh(self):
+        self.refresh_args()
+
+    def refresh_args(self):
+        rows = self.browser.list_function_args(self.oid)
+        self._arguments = OrderedDict([(x['name'], Argument(self.browser, self, **x)) for x in rows])
+
+    @property
+    def arguments(self):
+        if self._arguments is None:
+            self.refresh_args()
+        return self._arguments
+
     @property
     def definition(self):
         """Get full function definition including CREATE command."""
@@ -306,9 +334,12 @@
         return self._query('''
             SELECT
                 p.oid as "oid",
-                p.proname as "name",
+                p.proname || '(' || array_to_string(
+                    array(SELECT pg_catalog.format_type(unnest(p.proargtypes), NULL)),
+                      ', '
+                ) || ')' as "name",
+                p.proname as "function_name",
                 pg_catalog.pg_get_function_result(p.oid) as "result",
-                pg_catalog.pg_get_function_arguments(p.oid) as "arguments",
                 p.prosrc as "source",
                 CASE
                     WHEN p.proisagg THEN 'agg'
@@ -318,11 +349,41 @@
                 END as "type"
             FROM pg_catalog.pg_proc p
             LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
-            WHERE pg_catalog.pg_function_is_visible(p.oid)
-                AND n.nspname = %s
+            WHERE n.nspname = %s
             ORDER BY 1, 2, 4;
             ''', [schema])
 
+    def list_function_args(self, oid):
+        """List function arguments.
+
+        Notes about query:
+            type: Use allargtypes if present, argtypes otherwise.
+                The trick with [0:999] moves lower bound from 0 to default 1
+                by slicing all elements (slices has always lower bound 1).
+            mode: This trick makes array of NULLs of same length as argnames,
+                in case argmodes is NULL.
+            default: Use pg_get_expr, split output by ', '
+                FIXME: will fail if ', ' is present in default value string.
+        """
+        return self._query('''
+            SELECT
+              unnest(p.proargnames) AS "name",
+              pg_catalog.format_type(unnest(
+                COALESCE(p.proallargtypes, (p.proargtypes::oid[])[0:999])
+              ), NULL) AS "type",
+              unnest(
+                COALESCE(
+                  p.proargmodes::text[],
+                  array(SELECT NULL::text FROM generate_series(1, array_upper(p.proargnames, 1)))
+                )
+              ) AS "mode",
+              unnest(array_cat(
+                array_fill(NULL::text, array[COALESCE(array_upper(p.proargnames,1),0) - p.pronargdefaults]),
+                string_to_array(pg_get_expr(p.proargdefaults, 'pg_proc'::regclass, true), ', ')
+              )) AS "default"
+            FROM pg_proc p
+            WHERE p.oid = %s''', [oid])
+
     def get_function_definition(self, oid):
         """Get full function definition, including CREATE command etc.
 
--- a/pgtoolkit/pgdiff.py	Thu Jan 24 17:11:17 2013 +0100
+++ b/pgtoolkit/pgdiff.py	Fri Jan 25 17:06:54 2013 +0100
@@ -118,6 +118,18 @@
         self.name = table
 
 
+class DiffArgument(DiffBase):
+    def __init__(self, change, schema, function, argument):
+        DiffBase.__init__(self)
+        self.level = 2
+        self.type = 'argument'
+        self.change = change
+        self.schema = schema
+        self.function = function
+        self.argument = argument
+        self.name = argument
+
+
 class DiffFunction(DiffBase):
     def __init__(self, change, schema, function):
         DiffBase.__init__(self)
@@ -247,16 +259,6 @@
             if x not in src:
                 yield ('+', x)
 
-    def _compare_functions(self, a, b):
-        diff = []
-        if a.arguments != b.arguments:
-            diff.append(('args', a.arguments, b.arguments))
-        if a.result != b.result:
-            diff.append(('result', a.result, b.result))
-        if a.source != b.source:
-            diff.append(('source', a.source, b.source))
-        return diff
-
     def _compare_columns(self, a, b):
         diff = []
         if a.type != b.type:
@@ -275,6 +277,24 @@
             diff.append(('definition', a.definition, b.definition))
         return diff
 
+    def _compare_functions(self, a, b):
+        diff = []
+        if a.result != b.result:
+            diff.append(('result', a.result, b.result))
+        if a.source != b.source:
+            diff.append(('source', a.source, b.source))
+        return diff
+
+    def _compare_arguments(self, a, b):
+        diff = []
+        if a.type != b.type:
+            diff.append(('type', a.type, b.type))
+        if a.mode != b.mode:
+            diff.append(('mode', a.mode, b.mode))
+        if a.default != b.default:
+            diff.append(('default', a.default, b.default))
+        return diff
+
     def _diff_columns(self, schema, table, src_columns, dst_columns):
         for nd in self._diff_names(src_columns, dst_columns):
             if nd[1] in dst_columns:
@@ -336,16 +356,37 @@
             else:
                 yield tdo
 
+    def _diff_arguments(self, schema, function, src_args, dst_args):
+        for nd in self._diff_names(src_args, dst_args):
+            ado = DiffArgument(change=nd[0], schema=schema, function=function, argument=nd[1])
+            if nd[0] == '*':
+                a = src_args[nd[1]]
+                b = dst_args[nd[1]]
+                ado.changes = self._compare_arguments(a, b)
+                if ado.changes:
+                    yield ado
+            else:
+                yield ado
+
     def _diff_functions(self, schema, src_functions, dst_functions):
         for nd in self._diff_names(src_functions, dst_functions):
             fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1])
             if nd[0] == '*':
-                # compare function body and arguments
+                # compare function body and result
                 a = src_functions[nd[1]]
                 b = dst_functions[nd[1]]
                 fdo.changes = self._compare_functions(a, b)
                 if fdo.changes:
                     yield fdo
+                    fdo = None
+                # arguments
+                src_args = src_functions[nd[1]].arguments
+                dst_args = dst_functions[nd[1]].arguments
+                for ado in self._diff_arguments(schema, nd[1], src_args, dst_args):
+                    if fdo:
+                        yield fdo
+                        fdo = None
+                    yield ado
             else:
                 yield fdo