# -*- coding: utf-8 -*-
#
# PgDiff - capture differences of database metadata
#
# Depends on PgBrowser
#
# Copyright (c) 2011  Radek Brich <radek.brich@devl.cz>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


from pycolib.ansicolor import *

import re
import difflib


class PgDiffError(Exception):
    '''Error raised by PgDiff, e.g. when a requested schema is not found.'''


class DiffBase:
    '''Base class for a single entry ("line") of a schema diff.

    The formatting methods read these attributes, which subclasses set
    in their constructors:

    level   -- nesting depth; format() indents by two spaces per level
    type    -- kind of object as a string ('schema', 'table', 'column', ...)
    name    -- name of the object
    change  -- '+', '-' or '*' (keys of COLORS and COMMANDS)
    changes -- None, or an iterable of (what, old_value, new_value) tuples

    '''

    # Highlight attributes used by format() for each change marker.
    COLORS = {
        '+' : BOLD | GREEN,
        '-' : BOLD | RED,
        '*' : BOLD | YELLOW,
    }

    # SQL verb used by the generic format_patch() for each change marker.
    COMMANDS = {
        '+' : 'CREATE',
        '-' : 'DROP',
        '*' : 'ALTER',
    }

    def __init__(self):
        self.changes = None

    def format(self):
        '''Return this entry as one human readable (highlighted) string.

        The string consists of indentation, the change marker, object type
        and name, optionally followed by the description of changes in
        parentheses.

        '''
        out = ['  ' * self.level]

        out.append(highlight(1, self.COLORS[self.change]))
        out.append(self.change)

        out += [' ', self.type, ' ', self.name, highlight(0)]

        if self.changes:
            out += [highlight(1, WHITE), ' (', self._formatchanges(), ')', highlight(0)]

        return ''.join(out)

    def _formatnotnull(self, notnull):
        '''Return 'NOT NULL' for a true flag, None otherwise.'''
        if notnull:
            return 'NOT NULL'
        return None

    def _formatchanges(self):
        '''Describe all entries of self.changes as space separated sentences.

        Each entry is a (what, old_value, new_value) tuple. Depending on
        which of the two values is set, the sentence starts with
        'Changed', 'Removed' or 'Added'. Entries where neither value is
        set have nothing to show and are skipped.

        '''
        res = []
        for kind, a, b in self.changes:
            if kind == 'notnull':
                # NOT NULL is shown as the value itself, without a label
                kind = ''
                a = self._formatnotnull(a)
                b = self._formatnotnull(b)

            if a and b:
                parts = ['Changed ', kind, ' from ',
                    highlight(1,15), a, highlight(0), ' to ',
                    highlight(1,15), b, highlight(0), '.']
            elif a:
                parts = ['Removed ']
                if kind:
                    parts += [kind, ' ']
                parts += [highlight(1,15), a, highlight(0), '.']
            elif b:
                parts = ['Added ']
                if kind:
                    parts += [kind, ' ']
                parts += [highlight(1,15), b, highlight(0), '.']
            else:
                # Both values are empty (e.g. None vs ''). Previously this
                # either raised UnboundLocalError or repeated the previous
                # sentence, because no text was built for this entry.
                continue
            res.append(''.join(parts))
        return ' '.join(res)

    def format_patch(self):
        '''Return a list of SQL statements for this entry, or None.

        The generic form is '<COMMAND> <TYPE> <name>;'. No statement is
        generated for a changed ('*') schema or table; None is returned
        in that case.

        '''
        if self.change == '*' and self.type in ('schema', 'table'):
            return None
        return ['%s %s %s;' % (self.COMMANDS[self.change], self.type.upper(), self.name)]


class DiffSchema(DiffBase):
    '''Diff entry for a schema (top level, not indented).'''

    def __init__(self, change, schema):
        DiffBase.__init__(self)
        self.type = 'schema'
        self.level = 0
        self.change = change
        # the schema name doubles as the displayed name
        self.schema = self.name = schema


class DiffTable(DiffBase):
    '''Diff entry for a table (first nesting level).'''

    def __init__(self, change, schema, table):
        DiffBase.__init__(self)
        self.type = 'table'
        self.level = 1
        self.change = change
        self.schema = schema
        # the table name doubles as the displayed name
        self.table = self.name = table


class DiffArgument(DiffBase):
    '''Diff entry for a function argument (second nesting level).'''

    def __init__(self, change, schema, function, argument):
        DiffBase.__init__(self)
        self.type = 'argument'
        self.level = 2
        self.change = change
        self.schema = schema
        self.function = function
        # the argument name doubles as the displayed name
        self.argument = self.name = argument


class DiffFunction(DiffBase):
    '''Diff entry for a function (first nesting level).

    When show_body_diff is true, a change of the function source is
    rendered as a unified diff, otherwise only as 'Source differs.'.

    '''

    # Highlight attributes for unified diff lines, keyed by their first character.
    SOURCE_COLORS = {
        ' ': WHITE,
        '-': YELLOW,
        '+': GREEN,
        '@': WHITE | BOLD,
    }

    def __init__(self, change, schema, function, show_body_diff=False):
        DiffBase.__init__(self)
        self.level = 1
        self.type = 'function'
        self.change = change
        self.schema = schema
        self.function = function
        self.name = function
        self.show_body_diff = show_body_diff

    def _formatchanges(self):
        '''Describe all entries of self.changes as space separated sentences.

        Each entry is a (what, old_value, new_value) tuple. For 'source'
        entries the values are lists of source lines; for any other entry
        they are concatenated into the text as they are.

        '''
        res = []
        for kind, a, b in self.changes:
            if kind != 'source':
                res.append(''.join(['Changed ', kind, ' from ',
                    highlight(1,15), a, highlight(0), ' to ',
                    highlight(1,15), b, highlight(0), '.']))
            elif self.show_body_diff:
                res.append(self._format_source_diff(a, b))
            else:
                res.append('Source differs.')
        return ' '.join(res)

    def _format_source_diff(self, a, b):
        '''Return highlighted unified diff of two lists of source lines.'''
        diff = difflib.unified_diff(a, b, lineterm='')
        # Drop the two file header lines ('--- ', '+++ ') by position.
        # Matching them by content would also discard real changes such as
        # a removed SQL comment, which shows up in the diff as '--- text'.
        next(diff, None)
        next(diff, None)
        lines = ['Source differs:\n']
        for line in diff:
            color = self.SOURCE_COLORS.get(line[:1], WHITE)
            lines.append(highlight(1, color) + line + highlight(0) + '\n')
        return ''.join(lines)


class DiffColumn(DiffBase):
    '''Diff entry for a table column (second nesting level).

    columntype, columndefault and columnnotnull are the properties used
    when generating the statement for an added column.

    '''

    # Keyword placed before COLUMN in the ALTER TABLE statement.
    ALTER_COMMANDS = {
        '+' : 'ADD',
        '-' : 'DROP',
        '*' : 'ALTER',
    }

    def __init__(self, change, schema, table, column, columntype, columndefault, columnnotnull, changes=None):
        DiffBase.__init__(self)
        self.type = 'column'
        self.level = 2
        self.change = change
        self.schema = schema
        self.table = table
        self.column = self.name = column
        self.columntype = columntype
        self.columndefault = columndefault
        self.columnnotnull = columnnotnull
        self.changes = changes

    def format_patch(self):
        '''Return list of ALTER TABLE statements for this column change.

        A dropped or added column gives one statement. A changed column
        gives one statement per recognized change ('type', 'notnull',
        'default'), possibly none.

        '''
        prefix = 'ALTER TABLE %s.%s %s COLUMN %s' % (
            self.schema, self.table, self.ALTER_COMMANDS[self.change], self.name)

        if self.change == '-':
            return ['%s;' % prefix]

        if self.change == '+':
            notnull = ' NOT NULL' if self.columnnotnull else ''
            default = (' DEFAULT %s' % self.columndefault) if self.columndefault else ''
            return ['%s %s%s%s;' % (prefix, self.columntype, notnull, default)]

        statements = []
        if self.change == '*':
            for kind, old, new in self.changes:
                if kind == 'type':
                    statements.append('%s TYPE %s;' % (prefix, new))
                elif kind == 'notnull':
                    if old and not new:
                        statements.append('%s DROP NOT NULL;' % prefix)
                    elif new and not old:
                        statements.append('%s SET NOT NULL;' % prefix)
                elif kind == 'default':
                    if new:
                        statements.append('%s SET DEFAULT %s;' % (prefix, new))
                    else:
                        statements.append('%s DROP DEFAULT;' % prefix)
        return statements


class DiffConstraint(DiffBase):
    '''Diff entry for a table constraint (second nesting level).

    definition is the constraint definition used when the constraint
    has to be (re)created.

    '''

    def __init__(self, change, schema, table, constraint, definition, changes=None):
        DiffBase.__init__(self)
        self.type = 'constraint'
        self.level = 2
        self.change = change
        self.schema = schema
        self.table = table
        self.constraint = self.name = constraint
        self.definition = definition
        self.changes = changes

    def format_patch(self):
        '''Return list of ALTER TABLE statements for this constraint change.

        A changed constraint is dropped and added again.

        '''
        alter_table = 'ALTER TABLE %s.%s' % (self.schema, self.table)
        drop_stmt = '%s DROP CONSTRAINT %s;' % (alter_table, self.constraint)
        add_stmt = '%s ADD CONSTRAINT %s %s;' % (alter_table, self.constraint, self.definition)
        if self.change == '*':
            out = [drop_stmt, add_stmt]
        elif self.change == '+':
            out = [add_stmt]
        elif self.change == '-':
            out = [drop_stmt]
        return out


class DiffIndex(DiffBase):
    '''Diff entry for a table index (second nesting level).

    definition is the statement used when the index has to be
    (re)created; it is emitted as is, followed by a semicolon.

    '''

    def __init__(self, change, schema, table, index, definition, changes=None):
        DiffBase.__init__(self)
        self.level = 2
        self.type = 'index'
        self.change = change
        self.schema = schema
        self.table = table
        self.index = index
        self.name = index
        self.definition = definition
        self.changes = changes

    def format_patch(self):
        '''Return list of statements for this index change.

        A changed index is dropped and created again. An unrecognized
        change marker gives an empty list.

        '''
        out = []
        if self.change in ('*', '-'):
            # Qualify the index with its schema, like the column and
            # constraint statements do for their table, so the statement
            # does not depend on the current search_path.
            # NOTE(review): assumes self.index is an unqualified name - confirm.
            out.append('DROP INDEX %s.%s;' % (self.schema, self.index))
        if self.change in ('*', '+'):
            out.append('%s;' % (self.definition,))
        return out


class DiffType(DiffBase):
    '''Diff entry for a data type (first nesting level).'''

    def __init__(self, change, schema, name):
        DiffBase.__init__(self)
        self.type = 'type'
        self.level = 1
        self.change = change
        self.schema = schema
        self.name = name


class PgDiff:
    '''Compute differences between metadata of two databases.

    The source and destination databases are accessed through browser
    objects (srcbrowser, dstbrowser) which provide a 'schemas' mapping.
    The diff can be limited by schema and table names and by a regular
    expression for function names.

    '''

    def __init__(self, srcbrowser=None, dstbrowser=None):
        # NOTE(review): allowcolor is only assigned here and not read
        # anywhere in this class - presumably consumed elsewhere; confirm.
        self.allowcolor = False
        self.src = srcbrowser
        self.dst = dstbrowser
        self.include_schemas = set()  # if not empty, consider only these schemas for diff
        self.exclude_schemas = set()  # exclude these schemas from diff
        self.include_tables = set()  # if not empty, consider only these tables for diff
        self.exclude_tables = set()  # exclude these tables from diff
        self.function_regex = re.compile(r"")  # empty pattern matches every name
        self.function_body_diff = False

    def _test_schema(self, schema):
        '''Return True if the schema passes the include/exclude filter.'''
        if self.include_schemas and schema not in self.include_schemas:
            return False
        if schema in self.exclude_schemas:
            return False
        return True

    def _test_table(self, table):
        '''Return True if the table passes the include/exclude filter.'''
        if self.include_tables and table not in self.include_tables:
            return False
        if table in self.exclude_tables:
            return False
        return True

    def _test_function(self, function):
        '''Return True if the function name matches function_regex.

        The pattern is matched at the beginning of the name.

        '''
        return bool(self.function_regex.match(function))

    def _diff_names(self, src, dst):
        '''Yield (change, name) tuples for names found in src and dst.

        Names from src come first: '*' if the name is also in dst,
        '-' if not. Then names found only in dst are yielded as '+'.

        '''
        for x in src:
            if x in dst:
                yield ('*', x)
            else:
                yield ('-', x)
        for x in dst:
            if x not in src:
                yield ('+', x)

    def _compare_columns(self, a, b):
        '''Return list of (what, old, new) differences between two columns.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.notnull != b.notnull:
            diff.append(('notnull', a.notnull, b.notnull))
        if a.default != b.default:
            diff.append(('default', a.default, b.default))
        return diff

    def _compare_constraints(self, a, b):
        '''Return list of (what, old, new) differences between two constraints.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.definition != b.definition:
            diff.append(('definition', a.definition, b.definition))
        return diff

    def _compare_indexes(self, a, b):
        '''Return list of (what, old, new) differences between two indexes.'''
        diff = []
        if a.definition != b.definition:
            diff.append(('definition', a.definition, b.definition))
        return diff

    def _compare_functions(self, a, b):
        '''Return list of (what, old, new) differences between two functions.

        For 'source', the old and new values are lists of source lines.

        '''
        diff = []
        if a.result != b.result:
            diff.append(('result', a.result, b.result))
        # function source may differ in newlines (\n vs \r\n)
        # split lines before comparison, so that these differences are ignored
        a_source = a.source.splitlines()
        b_source = b.source.splitlines()
        if a_source != b_source:
            diff.append(('source', a_source, b_source))
        return diff

    def _compare_arguments(self, a, b):
        '''Return list of (what, old, new) differences between two arguments.'''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.mode != b.mode:
            diff.append(('mode', a.mode, b.mode))
        if a.default != b.default:
            diff.append(('default', a.default, b.default))
        return diff

    def _compare_types(self, a, b):
        '''Return list of (what, old, new) differences between two types.

        Elements are reported as their repr(), so that they can be
        joined into the formatted text.

        '''
        diff = []
        if a.type != b.type:
            diff.append(('type', a.type, b.type))
        if a.elements != b.elements:
            diff.append(('elements', repr(a.elements), repr(b.elements)))
        return diff

    def _diff_columns(self, schema, table, src_columns, dst_columns):
        '''Yield DiffColumn for each added, removed or changed column.'''
        for change, name in self._diff_names(src_columns, dst_columns):
            if name in dst_columns:
                dst_column = dst_columns[name]
                dst_type = dst_column.type
                dst_default = dst_column.default
                dst_notnull = dst_column.notnull
            else:
                dst_type = None
                dst_default = None
                dst_notnull = None
            cdo = DiffColumn(change=change, schema=schema, table=table, column=name,
                columntype=dst_type, columndefault=dst_default, columnnotnull=dst_notnull)
            if change == '*':
                # present on both sides - report only if something differs
                cdo.changes = self._compare_columns(src_columns[name], dst_columns[name])
                if cdo.changes:
                    yield cdo
            else:
                yield cdo

    def _diff_constraints(self, schema, table, src_constraints, dst_constraints):
        '''Yield DiffConstraint for each added, removed or changed constraint.'''
        for change, name in self._diff_names(src_constraints, dst_constraints):
            if name in dst_constraints:
                dst_definition = dst_constraints[name].definition
            else:
                dst_definition = None
            cdo = DiffConstraint(change=change, schema=schema, table=table, constraint=name,
                definition=dst_definition)
            if change == '*':
                # present on both sides - report only if something differs
                cdo.changes = self._compare_constraints(src_constraints[name], dst_constraints[name])
                if cdo.changes:
                    yield cdo
            else:
                yield cdo

    def _diff_indexes(self, schema, table, src_indexes, dst_indexes):
        '''Yield DiffIndex for each added, removed or changed index.'''
        for change, name in self._diff_names(src_indexes, dst_indexes):
            if name in dst_indexes:
                dst_definition = dst_indexes[name].definition
            else:
                dst_definition = None
            ido = DiffIndex(change=change, schema=schema, table=table, index=name,
                definition=dst_definition)
            if change == '*':
                # present on both sides - report only if something differs
                ido.changes = self._compare_indexes(src_indexes[name], dst_indexes[name])
                if ido.changes:
                    yield ido
            else:
                yield ido

    def _diff_tables(self, schema, src_tables, dst_tables):
        '''Yield diff entries for tables and their columns, constraints, indexes.

        Tables not passing the table filter are skipped. A table present
        on both sides is yielded (once, before its first child entry)
        only when some of its columns, constraints or indexes differ.

        '''
        for change, name in self._diff_names(src_tables, dst_tables):
            if not self._test_table(name):
                continue
            tdo = DiffTable(change=change, schema=schema, table=name)
            if change == '*':
                # columns
                src_columns = src_tables[name].columns
                dst_columns = dst_tables[name].columns
                for cdo in self._diff_columns(schema, name, src_columns, dst_columns):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield cdo
                # constraints
                src_constraints = src_tables[name].constraints
                dst_constraints = dst_tables[name].constraints
                for cdo in self._diff_constraints(schema, name, src_constraints, dst_constraints):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield cdo
                # indexes
                src_indexes = src_tables[name].indexes
                dst_indexes = dst_tables[name].indexes
                for ido in self._diff_indexes(schema, name, src_indexes, dst_indexes):
                    if tdo:
                        yield tdo
                        tdo = None
                    yield ido
            else:
                yield tdo

    def _diff_arguments(self, schema, function, src_args, dst_args):
        '''Yield DiffArgument for each added, removed or changed argument.'''
        for change, name in self._diff_names(src_args, dst_args):
            ado = DiffArgument(change=change, schema=schema, function=function, argument=name)
            if change == '*':
                # present on both sides - report only if something differs
                ado.changes = self._compare_arguments(src_args[name], dst_args[name])
                if ado.changes:
                    yield ado
            else:
                yield ado

    def _diff_functions(self, schema, src_functions, dst_functions):
        '''Yield diff entries for functions and their arguments.

        Functions not passing the function filter are skipped. A function
        present on both sides is yielded (once) when its result or source
        differs, or before its first differing argument.

        '''
        for change, name in self._diff_names(src_functions, dst_functions):
            if not self._test_function(name):
                continue
            fdo = DiffFunction(change=change, schema=schema, function=name, show_body_diff=self.function_body_diff)
            if change == '*':
                # compare function body and result
                a = src_functions[name]
                b = dst_functions[name]
                fdo.changes = self._compare_functions(a, b)
                if fdo.changes:
                    yield fdo
                    fdo = None
                # arguments
                src_args = src_functions[name].arguments
                dst_args = dst_functions[name].arguments
                for ado in self._diff_arguments(schema, name, src_args, dst_args):
                    if fdo:
                        yield fdo
                        fdo = None
                    yield ado
            else:
                yield fdo

    def _diff_types(self, schema, src_types, dst_types):
        '''Yield DiffType for each added, removed or changed type.'''
        for change, name in self._diff_names(src_types, dst_types):
            tdo = DiffType(change=change, schema=schema, name=name)
            if change == '*':
                # present on both sides - report only if something differs
                tdo.changes = self._compare_types(src_types[name], dst_types[name])
                if tdo.changes:
                    yield tdo
            else:
                yield tdo

    def iter_diff(self):
        '''Return diff between src and dst database schema.

        Yields one line at a time. Each line is in form of object
        inherited from DiffBase. This object contains all information
        about changes. See format() method.

        System schemas and schemas not passing the schema filter are
        left out. A schema present on both sides is yielded (once,
        before its first child entry) only when something in it differs.

        '''
        src_schemas = self.src.schemas
        dst_schemas = self.dst.schemas
        src = [x.name for x in src_schemas.values() if not x.system and self._test_schema(x.name)]
        dst = [x.name for x in dst_schemas.values() if not x.system and self._test_schema(x.name)]
        for change, name in self._diff_names(src, dst):
            sdo = DiffSchema(change=change, schema=name)
            if change == '*':
                # tables
                src_tables = src_schemas[name].tables
                dst_tables = dst_schemas[name].tables
                for tdo in self._diff_tables(name, src_tables, dst_tables):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield tdo
                # functions
                src_functions = src_schemas[name].functions
                dst_functions = dst_schemas[name].functions
                for fdo in self._diff_functions(name, src_functions, dst_functions):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield fdo
                # types
                src_types = src_schemas[name].types
                dst_types = dst_schemas[name].types
                for tdo in self._diff_types(name, src_types, dst_types):
                    if sdo:
                        yield sdo
                        sdo = None
                    yield tdo
            else:
                yield sdo

    def print_diff(self):
        '''Print diff between src and dst database schema.

        The output is in human readable form.

        Set allowcolor=True of PgDiff instance to get colored output.

        '''
        for ln in self.iter_diff():
            print(ln.format())

    def print_patch(self):
        '''Print patch for updating from src schema to dst schema.

        Supports table drop, add, column drop, add and following
        changes of columns:
          - type
          - set/remove not null
          - default value

        This is experimental, not tested very much.
        Do not use without checking the commands.
        Even if it works as intended, it can cause table lock ups
        and/or loss of data. You have been warned.

        '''
        for ln in self.iter_diff():
            patch = ln.format_patch()
            if patch:
                print('\n'.join(patch))

    def filter_schemas(self, include=(), exclude=()):
        '''Modify list of schemas which are used for computing diff.

        include (iterable) -- if not empty, consider only these schemas for diff
        exclude (iterable) -- exclude these schemas from diff

        Order: include, exclude
        Empty include means include everything.

        Raises:
            PgDiffError: when schema from include list is not found in src db

        '''
        # validate first, so that a failure leaves the current filter untouched
        for schema in include:
            self._check_schema_exist(schema)
        self.include_schemas.clear()
        self.include_schemas.update(include)
        self.exclude_schemas.clear()
        self.exclude_schemas.update(exclude)

    def filter_tables(self, include=(), exclude=()):
        '''Modify list of tables which are used for computing diff.

        include (iterable) -- if not empty, consider only these tables for diff
        exclude (iterable) -- exclude these tables from diff

        '''
        self.include_tables.clear()
        self.include_tables.update(include)
        self.exclude_tables.clear()
        self.exclude_tables.update(exclude)

    def filter_functions(self, regex=''):
        '''Consider only functions whose name matches regex (from its start).'''
        self.function_regex = re.compile(regex)

    def _check_schema_exist(self, schema):
        '''Raise PgDiffError if the schema is not present in source database.'''
        if schema not in self.src.schemas:
            raise PgDiffError('Schema "%s" not found in source database.' % schema)

