pgtools/pgdiff.py
changeset 7 685b20d2d3ab
parent 6 4ab077c93b2d
child 8 2911935c524d
equal deleted inserted replaced
6:4ab077c93b2d 7:685b20d2d3ab
     1 # -*- coding: utf-8 -*-
       
     2 #
       
     3 # PgDiff - capture differences of database metadata
       
     4 #
       
     5 # Depends on PgBrowser
       
     6 #
       
     7 # Copyright (c) 2011  Radek Brich <radek.brich@devl.cz>
       
     8 #
       
     9 # Permission is hereby granted, free of charge, to any person obtaining a copy
       
    10 # of this software and associated documentation files (the "Software"), to deal
       
    11 # in the Software without restriction, including without limitation the rights
       
    12 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       
    13 # copies of the Software, and to permit persons to whom the Software is
       
    14 # furnished to do so, subject to the following conditions:
       
    15 #
       
    16 # The above copyright notice and this permission notice shall be included in
       
    17 # all copies or substantial portions of the Software.
       
    18 #
       
    19 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
       
    20 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
       
    21 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
       
    22 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
       
    23 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
       
    24 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
       
    25 # THE SOFTWARE.
       
    26 
       
    27 
       
    28 from common.highlight import *
       
    29 
       
    30 
       
    31 class DiffBase:
       
    32     COLORS = {
       
    33         '+' : BOLD | GREEN,
       
    34         '-' : BOLD | RED,
       
    35         '*' : BOLD | YELLOW}
       
    36     
       
    37     def __init__(self):
       
    38         self.changes = None
       
    39 
       
    40     def format(self):
       
    41         out = ['  ' * self.level]
       
    42 
       
    43         out.append(highlight(1, self.COLORS[self.change]))
       
    44         out.append(self.change)
       
    45         
       
    46         out += [' ', self.type, ' ', self.name, highlight(0)]
       
    47         
       
    48         if self.changes:
       
    49             out += [highlight(1, WHITE), ' (', self.formatchanges(), ')', highlight(0)]
       
    50 
       
    51         return ''.join(out)
       
    52 
       
    53     def formatnotnull(self, notnull):
       
    54         if notnull:
       
    55             return 'NOT NULL'
       
    56         else:
       
    57             return None
       
    58     
       
    59     def formatchanges(self):
       
    60         res = []
       
    61         for x in self.changes:
       
    62             type, a, b = x
       
    63             if type == 'notnull':
       
    64                 type = ''
       
    65                 a = self.formatnotnull(a)
       
    66                 b = self.formatnotnull(b)
       
    67                 
       
    68             if a and b:
       
    69                 s = ''.join(['Changed ', type, ' from ',
       
    70                     highlight(1,15), a, highlight(0), ' to ',
       
    71                     highlight(1,15), b, highlight(0), '.'])
       
    72             elif a and not b:
       
    73                 l = ['Removed ']
       
    74                 if type:
       
    75                     l += [type, ' ']
       
    76                 l += [highlight(1,15), a, highlight(0), '.']
       
    77                 s = ''.join(l)
       
    78             elif b and not a:
       
    79                 l = ['Added ']
       
    80                 if type:
       
    81                     l += [type, ' ']
       
    82                 l += [highlight(1,15), b, highlight(0), '.']
       
    83                 s = ''.join(l)
       
    84             res.append(s)
       
    85         return ' '.join(res)
       
    86 
       
    87 
       
    88 class DiffSchema(DiffBase):
       
    89     def __init__(self, change, schema):
       
    90         DiffBase.__init__(self)
       
    91         self.level = 0
       
    92         self.type = 'schema'
       
    93         self.change = change
       
    94         self.schema = schema
       
    95         self.name = schema
       
    96         
       
    97 
       
    98 class DiffTable(DiffBase):
       
    99     def __init__(self, change, schema, table):
       
   100         DiffBase.__init__(self)
       
   101         self.level = 1
       
   102         self.type = 'table'
       
   103         self.change = change
       
   104         self.schema = schema
       
   105         self.table = table
       
   106         self.name = table
       
   107 
       
   108 
       
   109 class DiffColumn(DiffBase):
       
   110     def __init__(self, change, schema, table, column, changes=None):
       
   111         DiffBase.__init__(self)
       
   112         self.level = 2
       
   113         self.type = 'column'
       
   114         self.change = change
       
   115         self.schema = schema
       
   116         self.table = table
       
   117         self.column = column
       
   118         self.name = column
       
   119         self.changes = changes
       
   120 
       
   121 
       
   122 class DiffConstraint(DiffBase):
       
   123     def __init__(self, change, schema, table, constraint, changes=None):
       
   124         DiffBase.__init__(self)
       
   125         self.level = 2
       
   126         self.type = 'constraint'
       
   127         self.change = change
       
   128         self.schema = schema
       
   129         self.table = table
       
   130         self.constraint = constraint
       
   131         self.name = constraint
       
   132         self.changes = changes
       
   133 
       
   134 
       
   135 class PgDiff:
       
   136     def __init__(self, srcbrowser=None, dstbrowser=None):
       
   137         self.allowcolor = False
       
   138         self.src = srcbrowser
       
   139         self.dst = dstbrowser
       
   140         self.include_schemas = set()  # if not empty, consider only these schemas for diff
       
   141         self.exclude_schemas = set()  # exclude these schemas from diff
       
   142         self.include_tables = set()
       
   143         self.exclude_tables = set()
       
   144     
       
   145     def _test_filter(self, schema):
       
   146         if self.include_schemas and schema not in self.include_schemas:
       
   147             return False
       
   148         if schema in self.exclude_schemas:
       
   149             return False
       
   150         return True
       
   151     
       
   152     def _test_table(self, schema, table):
       
   153         name = schema + '.' + table
       
   154         if self.include_tables and name not in self.include_tables:
       
   155             return False
       
   156         if name in self.exclude_tables:
       
   157             return False
       
   158         return True 
       
   159     
       
   160     def _diffnames(self, src, dst):
       
   161         for x in src:
       
   162             if x in dst:
       
   163                 yield ('*', x)
       
   164             else:
       
   165                 yield ('-', x)
       
   166         for x in dst:
       
   167             if x not in src:
       
   168                 yield ('+', x)
       
   169 
       
   170     def _compare_columns(self, a, b):
       
   171         diff = []
       
   172         if a.type != b.type:
       
   173             diff.append(('type', a.type, b.type))
       
   174         if a.notnull != b.notnull:
       
   175             diff.append(('notnull', a.notnull, b.notnull))
       
   176         if a.default != b.default:
       
   177             diff.append(('default', a.default, b.default))
       
   178         return diff
       
   179     
       
   180     def _compare_constraints(self, a, b):
       
   181         diff = []
       
   182         if a.type != b.type:
       
   183             diff.append(('type', a.type, b.type))
       
   184         if a.definition != b.definition:
       
   185             diff.append(('definition', a.definition, b.definition))
       
   186         return diff
       
   187                 
       
   188     def _diff_columns(self, schema, table, src_columns, dst_columns):
       
   189         for nd in self._diffnames(src_columns, dst_columns):
       
   190             cdo = DiffColumn(change=nd[0], schema=schema, table=table, column=nd[1])
       
   191             if nd[0] == '*':
       
   192                 a = src_columns[nd[1]]
       
   193                 b = dst_columns[nd[1]]
       
   194                 cdo.changes = self._compare_columns(a, b)
       
   195                 if cdo.changes:
       
   196                     yield cdo
       
   197             else:
       
   198                 yield cdo
       
   199 
       
   200     def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
       
   201         for nd in self._diffnames(src_constraints, dst_constraints):
       
   202             cdo = DiffConstraint(change=nd[0], schema=schema, table=table, constraint=nd[1])
       
   203             if nd[0] == '*':
       
   204                 a = src_constraints[nd[1]]
       
   205                 b = dst_constraints[nd[1]]
       
   206                 cdo.changes = self._compare_constraints(a, b)
       
   207                 if cdo.changes:
       
   208                     yield cdo
       
   209             else:
       
   210                 yield cdo
       
   211 
       
   212     def _difftables(self, schema, src_tables, dst_tables):
       
   213         for nd in self._diffnames(src_tables, dst_tables):
       
   214             if not self._test_table(schema, nd[1]):
       
   215                 continue
       
   216             tdo = DiffTable(change=nd[0], schema=schema, table=nd[1])
       
   217             if nd[0] == '*':
       
   218                 # columns
       
   219                 src_columns = src_tables[nd[1]].columns
       
   220                 dst_columns = dst_tables[nd[1]].columns
       
   221                 for cdo in self._diff_columns(schema, nd[1], src_columns, dst_columns):
       
   222                     if tdo:
       
   223                         yield tdo
       
   224                         tdo = None
       
   225                     yield cdo
       
   226                 # constraints
       
   227                 src_constraints = src_tables[nd[1]].constraints
       
   228                 dst_constraints = dst_tables[nd[1]].constraints
       
   229                 for cdo in self._diff_constraints(schema, nd[1], src_constraints, dst_constraints):
       
   230                     if tdo:
       
   231                         yield tdo
       
   232                         tdo = None
       
   233                     yield cdo
       
   234             else:
       
   235                 yield tdo
       
   236 
       
   237     def iter_diff(self):
       
   238         '''Return diff between src and dst database schema.
       
   239 
       
   240         Yields one line at the time. Each line is in form of object
       
   241         iherited from DiffBase. This object contains all information
       
   242         about changes. See format() method.
       
   243         
       
   244         '''
       
   245         src_schemas = self.src.schemas
       
   246         dst_schemas = self.dst.schemas
       
   247         src = [x.name for x in src_schemas.values() if not x.system and self._test_filter(x.name)]
       
   248         dst = [x.name for x in dst_schemas.values() if not x.system and self._test_filter(x.name)]
       
   249         for nd in self._diffnames(src, dst):
       
   250             sdo = DiffSchema(change=nd[0], schema=nd[1])
       
   251             if nd[0] == '*':
       
   252                 src_tables = src_schemas[nd[1]].tables
       
   253                 dst_tables = dst_schemas[nd[1]].tables
       
   254                 for tdo in self._difftables(nd[1], src_tables, dst_tables):
       
   255                     if sdo:
       
   256                         yield sdo
       
   257                         sdo = None
       
   258                     yield tdo
       
   259             else:
       
   260                 yield sdo
       
   261 
       
   262     def print_diff(self):
       
   263         '''Print diff between src and dst database schema.
       
   264         
       
   265         The output is in human readable form.
       
   266         
       
   267         Set allowcolor=True of PgDiff instance to get colored output. 
       
   268         
       
   269         '''
       
   270         for ln in self.iter_diff():
       
   271             print(ln.format())
       
   272         
       
   273     def print_patch(self):
       
   274         '''Print patch for updating from src schema to dst schema.
       
   275         
       
   276         Supports table drop, add, column drop, add and following
       
   277         changes of columns:
       
   278           - type
       
   279           - set/remove not null
       
   280           - default value
       
   281         
       
   282         This is experimental, not tested very much.
       
   283         Do not use without checking the commands.
       
   284         Even if it works as intended, it can cause table lock ups
       
   285         and/or loss of data. You have been warned.
       
   286         
       
   287         '''
       
   288         for ln in self.iter_diff():
       
   289             print(ln.format_patch())
       
   290 
       
   291     def filter_schemas(self, include=[], exclude=[]):
       
   292         '''Modify list of schemas which are used for computing diff.
       
   293         
       
   294         include (list) -- if not empty, consider only these schemas for diff
       
   295         exclude (list) -- exclude these schemas from diff
       
   296         
       
   297         Order: include, exclude
       
   298         include=[] means include everything
       
   299         '''
       
   300         self.include_schemas.clear()
       
   301         self.include_schemas.update(include)
       
   302         self.exclude_schemas.clear()
       
   303         self.exclude_schemas.update(exclude)
       
   304 
       
   305 
       
   306     def filter_tables(self, include=[], exclude=[]):
       
   307         self.include_tables.clear()
       
   308         self.include_tables.update(include)
       
   309         self.exclude_tables.clear()
       
   310         self.exclude_tables.update(exclude)
       
   311