pgtoolkit/pgdiff.py
changeset 58 0bcc13460dae
parent 53 4a049a5af657
child 59 65efd0c6919f
--- a/pgtoolkit/pgdiff.py	Mon Dec 17 21:12:04 2012 +0100
+++ b/pgtoolkit/pgdiff.py	Thu Jan 24 17:11:17 2013 +0100
@@ -118,6 +118,30 @@
         self.name = table
 
 
+class DiffFunction(DiffBase):
+    def __init__(self, change, schema, function):
+        DiffBase.__init__(self)
+        self.level = 1
+        self.type = 'function'
+        self.change = change
+        self.schema = schema
+        self.function = function
+        self.name = function
+
+    def _formatchanges(self):
+        res = []
+        for x in self.changes:
+            type, a, b = x
+            if type == 'source':
+                s = 'Changed source.'
+            else:
+                s = ''.join(['Changed ', type, ' from ',
+                    highlight(1,15), a, highlight(0), ' to ',
+                    highlight(1,15), b, highlight(0), '.'])
+            res.append(s)
+        return ' '.join(res)
+
+
 class DiffColumn(DiffBase):
     ALTER_COMMANDS = {
         '+' : 'ADD',
@@ -223,6 +247,16 @@
             if x not in src:
                 yield ('+', x)
 
+    def _compare_functions(self, a, b):
+        diff = []
+        if a.arguments != b.arguments:
+            diff.append(('args', a.arguments, b.arguments))
+        if a.result != b.result:
+            diff.append(('result', a.result, b.result))
+        if a.source != b.source:
+            diff.append(('source', a.source, b.source))
+        return diff
+
     def _compare_columns(self, a, b):
         diff = []
         if a.type != b.type:
@@ -277,7 +311,7 @@
             else:
                 yield cdo
 
-    def _difftables(self, schema, src_tables, dst_tables):
+    def _diff_tables(self, schema, src_tables, dst_tables):
         for nd in self._diff_names(src_tables, dst_tables):
             if not self._test_table(nd[1]):
                 continue
@@ -302,6 +336,19 @@
             else:
                 yield tdo
 
+    def _diff_functions(self, schema, src_functions, dst_functions):
+        for nd in self._diff_names(src_functions, dst_functions):
+            fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1])
+            if nd[0] == '*':
+                # compare function body and arguments
+                a = src_functions[nd[1]]
+                b = dst_functions[nd[1]]
+                fdo.changes = self._compare_functions(a, b)
+                if fdo.changes:
+                    yield fdo
+            else:
+                yield fdo
+
     def iter_diff(self):
         '''Return diff between src and dst database schema.
 
@@ -317,13 +364,22 @@
         for nd in self._diff_names(src, dst):
             sdo = DiffSchema(change=nd[0], schema=nd[1])
             if nd[0] == '*':
+                # tables
                 src_tables = src_schemas[nd[1]].tables
                 dst_tables = dst_schemas[nd[1]].tables
-                for tdo in self._difftables(nd[1], src_tables, dst_tables):
+                for tdo in self._diff_tables(nd[1], src_tables, dst_tables):
                     if sdo:
                         yield sdo
                         sdo = None
                     yield tdo
+                # functions
+                src_functions = src_schemas[nd[1]].functions
+                dst_functions = dst_schemas[nd[1]].functions
+                for fdo in self._diff_functions(nd[1], src_functions, dst_functions):
+                    if sdo:
+                        yield sdo
+                        sdo = None
+                    yield fdo
             else:
                 yield sdo