pgtoolkit/pgdatadiff.py
changeset 31 c2e6e24b83d9
parent 27 5fb4883604d6
child 41 6aad5e35efe8
--- a/pgtoolkit/pgdatadiff.py	Mon Feb 27 15:12:40 2012 +0100
+++ b/pgtoolkit/pgdatadiff.py	Mon Mar 05 18:36:46 2012 +0100
@@ -37,11 +37,19 @@
         'V' : BOLD | WHITE,
         'K' : BOLD | BLUE}
     
-    def __init__(self, change, cols1, cols2, id=None):
+    def __init__(self, change, cols1, cols2, key=None):
+        '''Store one row-level difference.
+        
+        change - one of '+', '-', '*' (add, remove, update)
+        cols1 - original column values (OrderedDict); None when change is '+'
+        cols2 - new column values (OrderedDict); None when change is '-'
+        key - primary key column names mapped to their values (OrderedDict), or None
+        
+        '''
         self.change = change
         self.cols1 = cols1
         self.cols2 = cols2
-        self.id = id
+        self.key = key
     
     def format(self):
         out = []
@@ -70,8 +78,9 @@
             return [', '.join([self._format_value_add(*x) for x in self.cols2.items()])]
         
         out = []        
-        if self.id:
-            out.extend([highlight(1, self.COLORS['*']), self.id[0], ': ', highlight(0), self.id[1], ', '])
+        if self.key:
+            for colname in self.key:
+                out.extend([highlight(1, self.COLORS['*']), colname, ': ', highlight(0), self.key[colname], ', '])
 
         items = []
         for i in range(len(self.cols1)):
@@ -122,9 +131,9 @@
 
     def _format_where(self):
         out = [' WHERE ']
-        out.extend([self.id[0], ' = '])
-        out.append(self.id[1])
-        out.append(';')
+        for colname in self.key:  # one "column = value" condition per key column
+            out.extend([colname, ' = ', self.key[colname], ' AND '])
+        out[-1] = ';'  # replaces the trailing ' AND '; NOTE(review): an empty key would overwrite ' WHERE ' itself
         return out
 
 class PgDataDiff:
@@ -196,14 +205,22 @@
 
     def _select(self):
         browser = pgbrowser.PgBrowser(self.conn1)
+        
         columns = browser.list_columns(schema=self.schema1, table=self.table1, order=1)
         if not columns:
             raise Exception('Table %s.%s not found.' % (self.schema1, self.table1))
         columns_sel = ', '.join(['"' + x['name'] + '"' for x in columns])
         self.colnames = [x['name'] for x in columns]
         
-        query1 = 'SELECT ' + columns_sel + ' FROM ' + self.fulltable1 + ' ORDER BY 1;'
-        query2 = 'SELECT ' + columns_sel + ' FROM ' + self.fulltable2 + ' ORDER BY 1;'
+        pkey = [ind for ind in browser.list_indexes(schema=self.schema1, table=self.table1) if ind['primary']]
+        if not pkey:
+            raise Exception('Table %s.%s has no primary key.' % (self.schema1, self.table1))
+        pkey = pkey[0]
+        pkey_sel = ', '.join(['"' + x + '"' for x in pkey['columns']])
+        self.pkeycolnames = pkey['columns']
+        
+        query1 = 'SELECT ' + columns_sel + ' FROM ' + self.fulltable1 + ' ORDER BY ' + pkey_sel
+        query2 = 'SELECT ' + columns_sel + ' FROM ' + self.fulltable2 + ' ORDER BY ' + pkey_sel
         
         curs1 = self.conn1.cursor()
         curs2 = self.conn2.cursor()
@@ -216,32 +233,31 @@
     def _compare_data(self, row1, row2):
         cols1 = OrderedDict()
         cols2 = OrderedDict()
-        for i in range(len(row1)):
-            if row1[i] != row2[i]:
-                cols1[self.colnames[i]] = row1[i]
-                cols2[self.colnames[i]] = row2[i]
+        for name in row1:  # iterating row1 yields the names used to index both rows
+            if row1[name] != row2[name]:  # keep only the columns whose values differ
+                cols1[name] = row1[name]
+                cols2[name] = row2[name]
         if cols1:
-            id = (self.colnames[0], row1[0])
-            return DiffData('*', cols1, cols2, id=id)
+            key = OrderedDict(zip(self.pkeycolnames, [row1[colname] for colname in self.pkeycolnames]))
+            return DiffData('*', cols1, cols2, key=key)  # '*' marks an update, keyed by row1's primary key values
         
         return None
     
     def _compare_row(self, row1, row2):
         if row2 is None:
-            cols1 = OrderedDict(zip(self.colnames, row1))
-            id = (self.colnames[0], row1[0])
-            return DiffData('-', cols1, None, id=id)
+            key = OrderedDict(zip(self.pkeycolnames, [row1[colname] for colname in self.pkeycolnames]))
+            return DiffData('-', row1, None, key=key)  # row1 has no counterpart: removal
         if row1 is None:
-            cols2 = OrderedDict(zip(self.colnames, row2))
-            return DiffData('+', None, cols2)
+            return DiffData('+', None, row2)  # row2 has no counterpart: addition
+        
         
-        if row1[0] < row2[0]:
-            cols1 = OrderedDict(zip(self.colnames, row1))
-            id = (self.colnames[0], row1[0])
-            return DiffData('-', cols1, None, id=id)
-        if row1[0] > row2[0]:
-            cols2 = OrderedDict(zip(self.colnames, row2))
-            return DiffData('+', None, cols2)
+        key1 = [row1[keyname] for keyname in self.pkeycolnames]
+        key2 = [row2[keyname] for keyname in self.pkeycolnames]
+        if key1 < key2:  # whole-key lexicographic compare, matching the multi-column ORDER BY
+            key = OrderedDict(zip(self.pkeycolnames, key1))
+            return DiffData('-', row1, None, key=key)
+        if key1 > key2:  # row2 sorts first, so it is missing from the first result set
+            return DiffData('+', None, row2)
         
         return self._compare_data(row1, row2)