pgtools/pgdatadiff.py
changeset 6 4ab077c93b2d
equal deleted inserted replaced
5:57cf8fdff5ed 6:4ab077c93b2d
       
     1 # -*- coding: utf-8 -*-
       
     2 #
       
     3 # PgDataDiff - compare tables, print data differences
       
     4 #
       
     5 # Copyright (c) 2011  Radek Brich <radek.brich@devl.cz>
       
     6 #
       
     7 # Permission is hereby granted, free of charge, to any person obtaining a copy
       
     8 # of this software and associated documentation files (the "Software"), to deal
       
     9 # in the Software without restriction, including without limitation the rights
       
    10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       
    11 # copies of the Software, and to permit persons to whom the Software is
       
    12 # furnished to do so, subject to the following conditions:
       
    13 #
       
    14 # The above copyright notice and this permission notice shall be included in
       
    15 # all copies or substantial portions of the Software.
       
    16 #
       
    17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
       
    18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
       
    19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
       
    20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
       
    21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
       
    22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
       
    23 # THE SOFTWARE.
       
    24 
       
    25 
       
    26 from psycopg2.extensions import adapt
       
    27 
       
    28 from common.highlight import *
       
    29 from collections import OrderedDict
       
    30 
       
    31 
       
    32 class DiffData:
       
    33     COLORS = {
       
    34         '+' : BOLD | GREEN,
       
    35         '-' : BOLD | RED,
       
    36         '*' : BOLD | YELLOW,
       
    37         'V' : BOLD | WHITE}
       
    38     
       
    39     def __init__(self, table, change, src_cols, dst_cols, id=None):
       
    40         self.table = table
       
    41         self.change = change
       
    42         self.src_cols = src_cols
       
    43         self.dst_cols = dst_cols
       
    44         self.id = id
       
    45     
       
    46     def format(self):
       
    47         out = []
       
    48         
       
    49         out.append(highlight(1, self.COLORS[self.change]))
       
    50         out.extend([self.change, ' '])
       
    51         
       
    52         out.extend(self._format_changes())
       
    53         
       
    54         out.append(highlight(0))
       
    55         
       
    56         return ''.join(out)
       
    57 
       
    58     def format_patch(self):
       
    59         method = {
       
    60             '+' : self._format_insert,
       
    61             '-' : self._format_delete,
       
    62             '*' : self._format_update}
       
    63         
       
    64         return method[self.change]()
       
    65 
       
    66     def _format_changes(self):
       
    67         if self.src_cols and not self.dst_cols:
       
    68             return [', '.join([self._format_value_del(*x) for x in self.src_cols.items()])]
       
    69         if not self.src_cols and self.dst_cols:
       
    70             return [', '.join([self._format_value_add(*x) for x in self.dst_cols.items()])]
       
    71         
       
    72         items = []
       
    73         for i in range(len(self.src_cols)):
       
    74             items.append((
       
    75                 list(self.src_cols.keys())[i],
       
    76                 list(self.src_cols.values())[i],
       
    77                 list(self.dst_cols.values())[i]))
       
    78             
       
    79         return [', '.join([self._format_value_change(*x) for x in items])]
       
    80 
       
    81     def _format_value_del(self, k, v):
       
    82         fs = (highlight(1, self.COLORS['-']) + '{}: ' + highlight(0) + '{}')
       
    83         return fs.format(k, adapt(v).getquoted().decode())
       
    84 
       
    85     def _format_value_add(self, k, v):
       
    86         fs = (highlight(1, self.COLORS['+']) + '{}: ' + highlight(0) + 
       
    87             highlight(1, self.COLORS['V']) + '{}' + highlight(0))
       
    88         return fs.format(k, adapt(v).getquoted().decode())
       
    89 
       
    90     def _format_value_change(self, k, v1, v2):
       
    91         fs = (highlight(1, self.COLORS['*']) + '{}: ' + highlight(0) + 
       
    92             '{} ▶ ' +
       
    93             highlight(1, self.COLORS['V']) + '{}' + highlight(0))
       
    94         return fs.format(k,
       
    95             adapt(v1).getquoted().decode(),
       
    96             adapt(v2).getquoted().decode())
       
    97 
       
    98     def _format_insert(self):
       
    99         out = ['INSERT INTO ', self.table, ' (']
       
   100         out.append(', '.join(self.dst_cols.keys()))
       
   101         out.append(') VALUES (')
       
   102         out.append(', '.join([adapt(v).getquoted().decode() for v in self.dst_cols.values()]))
       
   103         out.append(');')
       
   104         return ''.join(out)
       
   105     
       
   106     def _format_delete(self):
       
   107         out = ['DELETE FROM ', self.table]
       
   108         out.extend(self._format_where()) 
       
   109         return ''.join(out)
       
   110     
       
   111     def _format_update(self):
       
   112         out = ['UPDATE ', self.table, ' SET ']
       
   113         out.append(', '.join([self._format_set(*x) for x in self.dst_cols.items()]))
       
   114         out.extend(self._format_where())
       
   115         return ''.join(out)
       
   116 
       
   117     def _format_set(self, k, v):
       
   118         return '{} = {}'.format(k, adapt(v).getquoted().decode())
       
   119 
       
   120     def _format_where(self):
       
   121         out = [' WHERE ']
       
   122         out.extend([self.id[0], ' = '])
       
   123         out.append(adapt(self.id[1]).getquoted().decode())
       
   124         out.append(';')
       
   125         return out
       
   126 
       
   127 class PgDataDiff:
       
   128     def __init__(self, table=None, src_rows=None, dst_rows=None, col_names=None):
       
   129         self.allowcolor = False
       
   130         self.table = table
       
   131         self.src_rows = src_rows
       
   132         self.dst_rows = dst_rows
       
   133         self.col_names = col_names
       
   134     
       
   135     def iter_diff(self):
       
   136         '''Return differencies between data of two tables.
       
   137         
       
   138         Yields one line at the time.
       
   139         
       
   140         '''
       
   141         while True:
       
   142             try:
       
   143                 diff = self._compare_row(self.src_rows, self.dst_rows)
       
   144             except IndexError:
       
   145                 break
       
   146             
       
   147             if diff:
       
   148                 yield diff
       
   149         
       
   150     def print_diff(self):
       
   151         '''Print differencies between data of two tables.
       
   152         
       
   153         The output is in human readable form.
       
   154         
       
   155         Set allowcolor=True of PgDataDiff instance to get colored output.
       
   156         
       
   157         '''
       
   158         for ln in self.iter_diff():
       
   159             print(ln.format())
       
   160     
       
   161     def print_patch(self):
       
   162         '''Print SQL script usable as patch for destination table.
       
   163         
       
   164         Supports INSERT, DELETE and UPDATE operations.
       
   165         
       
   166         '''
       
   167         for ln in self.iter_diff():
       
   168             print(ln.format_patch())
       
   169 
       
   170     def _compare_data(self, src, dst):
       
   171         src_cols = OrderedDict()
       
   172         dst_cols = OrderedDict()
       
   173         for i in range(len(src)):
       
   174             if src[i] != dst[i]:
       
   175                 src_cols[self.col_names[i]] = src[i]
       
   176                 dst_cols[self.col_names[i]] = dst[i]
       
   177         if src_cols:
       
   178             id = (self.col_names[0], src[0])
       
   179             return DiffData(self.table, '*', src_cols, dst_cols, id=id)
       
   180         
       
   181         return None
       
   182     
       
   183     def _compare_row(self, src_rows, dst_rows):
       
   184         if len(src_rows) and not len(dst_rows):
       
   185             src = src_rows.pop(0)
       
   186             src_cols = OrderedDict(zip(self.col_names, src))
       
   187             return DiffData(self.table, '-', src_cols, None)
       
   188         if not len(src_rows) and len(dst_rows):
       
   189             dst = dst_rows.pop(0)
       
   190             dst_cols = OrderedDict(zip(self.col_names, dst))
       
   191             return DiffData(self.table, '+', None, dst_cols)
       
   192         
       
   193         src = src_rows[0]
       
   194         dst = dst_rows[0]
       
   195         
       
   196         if src[0] < dst[0]:
       
   197             del src_rows[0]
       
   198             src_cols = OrderedDict(zip(self.col_names, src))
       
   199             id = (self.col_names[0], src[0])
       
   200             return DiffData(self.table, '-', src_cols, None, id=id)
       
   201         if src[0] > dst[0]:
       
   202             del dst_rows[0]
       
   203             dst_cols = OrderedDict(zip(self.col_names, dst))
       
   204             return DiffData(self.table, '+', None, dst_cols)
       
   205         
       
   206         del src_rows[0]
       
   207         del dst_rows[0]
       
   208         return self._compare_data(src, dst)
       
   209