pgtoolkit/pgdiff.py
changeset 63 8c7f0a51ba50
parent 61 703bba757605
child 64 687e18e5ca93
--- a/pgtoolkit/pgdiff.py	Thu Jan 31 11:02:04 2013 +0100
+++ b/pgtoolkit/pgdiff.py	Thu Jan 31 13:24:57 2013 +0100
@@ -26,6 +26,10 @@
 
 
 from pgtoolkit.highlight import *
+from pgtoolkit.colordiff import colordiff
+
+import re
+import difflib
 
 
 class PgDiffError(Exception):
@@ -134,7 +138,7 @@
 
 
 class DiffFunction(DiffBase):
-    def __init__(self, change, schema, function):
+    def __init__(self, change, schema, function, show_body_diff=False):
         DiffBase.__init__(self)
         self.level = 1
         self.type = 'function'
@@ -142,18 +146,28 @@
         self.schema = schema
         self.function = function
         self.name = function
+        self.show_body_diff = show_body_diff
 
     def _formatchanges(self):
         res = []
         for x in self.changes:
             type, a, b = x
             if type == 'source':
-                s = 'Changed source.'
+                if self.show_body_diff:
+                    lines = ['Source differs:\n']
+                    for line in difflib.unified_diff(a, b, lineterm=''):
+                        if line[:3] in ('---', '+++'):
+                            continue
+                        lines.append(line + '\n')
+                    diff = ''.join(lines)
+                    diff = colordiff(diff)
+                    res.append(diff)
+                else:
+                    res.append('Source differs.')
             else:
-                s = ''.join(['Changed ', type, ' from ',
+                res.append(''.join(['Changed ', type, ' from ',
                     highlight(1,15), a, highlight(0), ' to ',
-                    highlight(1,15), b, highlight(0), '.'])
-            res.append(s)
+                    highlight(1,15), b, highlight(0), '.']))
         return ' '.join(res)
 
 
@@ -249,6 +263,8 @@
         self.exclude_schemas = set()  # exclude these schemas from diff
         self.include_tables = set()
         self.exclude_tables = set()
+        self.function_regex = re.compile(r"")
+        self.function_body_diff = False
 
     def _test_schema(self, schema):
         if self.include_schemas and schema not in self.include_schemas:
@@ -264,6 +280,9 @@
             return False
         return True
 
+    def _test_function(self, function):
+        return bool(self.function_regex.match(function))
+
     def _diff_names(self, src, dst):
         for x in src:
             if x in dst:
@@ -296,8 +315,12 @@
         diff = []
         if a.result != b.result:
             diff.append(('result', a.result, b.result))
-        if a.source != b.source:
-            diff.append(('source', a.source, b.source))
+        # function source may differ in newlines (\n vs \r\n)
+        # split lines before comparison, so that these differencies are ignored
+        a_source = a.source.splitlines()
+        b_source = b.source.splitlines()
+        if a_source != b_source:
+            diff.append(('source', a_source, b_source))
         return diff
 
     def _compare_arguments(self, a, b):
@@ -387,7 +410,9 @@
 
     def _diff_functions(self, schema, src_functions, dst_functions):
         for nd in self._diff_names(src_functions, dst_functions):
-            fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1])
+            if not self._test_function(nd[1]):
+                continue
+            fdo = DiffFunction(change=nd[0], schema=schema, function=nd[1], show_body_diff=self.function_body_diff)
             if nd[0] == '*':
                 # compare function body and result
                 a = src_functions[nd[1]]
@@ -492,13 +517,14 @@
         self.exclude_schemas.clear()
         self.exclude_schemas.update(exclude)
 
-
     def filter_tables(self, include=[], exclude=[]):
         self.include_tables.clear()
         self.include_tables.update(include)
         self.exclude_tables.clear()
         self.exclude_tables.update(exclude)
 
+    def filter_functions(self, regex=''):
+        self.function_regex = re.compile(regex)
 
     def _check_schema_exist(self, schema):
         if not schema in self.src.schemas: