Add basic support for types in browser and schema diff.
authorRadek Brich <brich.radek@ifortuna.cz>
Thu, 08 Aug 2013 15:26:24 +0200
changeset 85 11a282e23e0d
parent 84 3b5dd9efba35
child 86 b61b54aa9f96
Add basic support for types in browser and schema diff.
pgtoolkit/pgbrowser.py
pgtoolkit/pgdiff.py
--- a/pgtoolkit/pgbrowser.py	Wed Jul 24 13:11:37 2013 +0200
+++ b/pgtoolkit/pgbrowser.py	Thu Aug 08 15:26:24 2013 +0200
@@ -161,10 +161,21 @@
         return self._definition
 
 
+class Type:
+    def __init__(self, browser, schema, name, type, elements, description):
+        self.browser = browser
+        self.schema = schema
+        self.name = name
+        self.type = type
+        self.elements = elements
+        self.description = description
+
+
 class Schema:
     def __init__(self, browser, name, owner, acl, description, system):
         self._tables = None
         self._functions = None
+        self._types = None
         self.browser = browser
         self.name = name
         self.owner = owner
@@ -184,6 +195,10 @@
         rows = self.browser.list_functions(self.name)
         self._functions = OrderedDict([(x['name'], Function(self.browser, self, **x)) for x in rows])
 
+    def refresh_types(self):
+        rows = self.browser.list_types(self.name)
+        self._types = OrderedDict([(x['name'], Type(self.browser, self, **x)) for x in rows])
+
     @property
     def tables(self):
         if self._tables is None:
@@ -196,6 +211,12 @@
             self.refresh_functions()
         return self._functions
 
+    @property
+    def types(self):
+        if self._types is None:
+            self.refresh_types()
+        return self._types
+
 
 class PgBrowser:
     def __init__(self, conn=None):
@@ -395,6 +416,35 @@
         """
         return self._query('''select pg_get_functiondef(%s);''', [oid])
 
+    def list_types(self, schema='public'):
+        """List types in schema."""
+        return self._query('''
+            SELECT
+                t.typname AS "name",
+                CASE
+                    WHEN t.typtype = 'b' THEN 'base'::text
+                    WHEN t.typtype = 'c' THEN 'composite'::text
+                    WHEN t.typtype = 'd' THEN 'domain'::text
+                    WHEN t.typtype = 'e' THEN 'enum'::text
+                    WHEN t.typtype = 'p' THEN 'pseudo'::text
+                END AS "type",
+                ARRAY(
+                      SELECT e.enumlabel
+                      FROM pg_catalog.pg_enum e
+                      WHERE e.enumtypid = t.oid
+                      ORDER BY e.oid
+                ) AS "elements",
+                pg_catalog.obj_description(t.oid, 'pg_type') AS "description"
+            FROM pg_catalog.pg_type t
+            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+            WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
+              AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid)
+                  AND n.nspname <> 'pg_catalog'
+                  AND n.nspname <> 'information_schema'
+              AND n.nspname = %(schema)s
+            ORDER BY 1, 2;
+        ''', {'schema': schema})
+
     def list_sequences(self, schema=None):
         '''List sequences in schema.'''
         return self._query('''
--- a/pgtoolkit/pgdiff.py	Wed Jul 24 13:11:37 2013 +0200
+++ b/pgtoolkit/pgdiff.py	Thu Aug 08 15:26:24 2013 +0200
@@ -252,6 +252,16 @@
         return out
 
 
+class DiffType(DiffBase):
+    def __init__(self, change, schema, name):
+        DiffBase.__init__(self)
+        self.level = 1
+        self.type = 'type'
+        self.change = change
+        self.schema = schema
+        self.name = name
+
+
 class PgDiff:
     def __init__(self, srcbrowser=None, dstbrowser=None):
         self.allowcolor = False
@@ -331,6 +341,14 @@
             diff.append(('default', a.default, b.default))
         return diff
 
+    def _compare_types(self, a, b):
+        diff = []
+        if a.type != b.type:
+            diff.append(('type', a.type, b.type))
+        if a.elements != b.elements:
+            diff.append(('elements', repr(a.elements), repr(b.elements)))
+        return diff
+
     def _diff_columns(self, schema, table, src_columns, dst_columns):
         for nd in self._diff_names(src_columns, dst_columns):
             if nd[1] in dst_columns:
@@ -430,6 +448,18 @@
             else:
                 yield fdo
 
+    def _diff_types(self, schema, src_types, dst_types):
+        for nd in self._diff_names(src_types, dst_types):
+            tdo = DiffType(change=nd[0], schema=schema, name=nd[1])
+            if nd[0] == '*':
+                a = src_types[nd[1]]
+                b = dst_types[nd[1]]
+                tdo.changes = self._compare_types(a, b)
+                if tdo.changes:
+                    yield tdo
+            else:
+                yield tdo
+
     def iter_diff(self):
         '''Return diff between src and dst database schema.
 
@@ -461,6 +491,14 @@
                         yield sdo
                         sdo = None
                     yield fdo
+                # types
+                src_types = src_schemas[nd[1]].types
+                dst_types = dst_schemas[nd[1]].types
+                for tdo in self._diff_types(nd[1], src_types, dst_types):
+                    if sdo:
+                        yield sdo
+                        sdo = None
+                    yield tdo
             else:
                 yield sdo