# -*- coding: utf-8 -*-
#
# PgBrowser - browse database schema and metadata
#
# Some of the queries came from psql.
#
# Copyright (c) 2011  Radek Brich <radek.brich@devl.cz>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


from collections import OrderedDict


class Column:
    """A table column, built from one row of ``PgBrowser.list_columns``."""

    def __init__(self, browser, table,
                 name, type, notnull, hasdefault, default, description):
        # Back-references to the owning PgBrowser and Table objects.
        self.browser, self.table = browser, table
        # Catalog data for the column.
        self.name = name  # pg_attribute.attname
        self.type = type  # format_type(atttypid, atttypmod) text
        self.notnull = notnull  # pg_attribute.attnotnull
        self.hasdefault = hasdefault  # pg_attribute.atthasdef
        # pg_get_expr() of the default; NULL when there is no pg_attrdef row
        self.default = default
        self.description = description  # col_description() comment, if any


class Constraint:
    """A table constraint, built from one row of ``PgBrowser.list_constraints``."""

    def __init__(self, browser, table, name, type, fname, fschema, definition):
        # Owning objects.
        self.browser, self.table = browser, table
        # Identity and kind of the constraint.
        self.name = name  # pg_constraint.conname
        self.type = type  # pg_constraint.contype code (e.g. 'p', 'f')
        # Referenced (foreign) table; both are NULL/None for non-foreign-key
        # constraints because list_constraints LEFT JOINs them.
        self.fname = fname  # foreign table name
        self.fschema = fschema  # foreign table schema
        self.definition = definition  # pg_get_constraintdef() text


class Index:
    """A table index, built from one row of ``PgBrowser.list_indexes``."""

    def __init__(self, browser, table,
        name, primary, unique, clustered, valid, definition, columns=None):
        self.browser = browser  # PgBrowser instance
        self.table = table  # Table instance the index belongs to
        self.name = name  # index relation name
        self.primary = primary  # pg_index.indisprimary
        self.unique = unique  # pg_index.indisunique
        self.clustered = clustered  # pg_index.indisclustered
        self.valid = valid  # pg_index.indisvalid
        self.definition = definition  # pg_get_indexdef() text
        # Attribute names of the index relation, i.e. the "columns" field of
        # list_indexes. Table.refresh_indexes passes every row field as a
        # keyword argument, so this parameter must exist or construction
        # fails with TypeError. It defaults to None for callers that
        # construct an Index without it.
        self.columns = columns


class Table:
    """A table within a schema; columns, constraints and indexes load lazily."""

    def __init__(self, browser, schema, name, owner, size, description):
        # Lazy caches; each stays None until the matching refresh_* runs.
        self._columns = self._constraints = self._indexes = None
        self.browser = browser  # PgBrowser instance
        self.schema = schema  # Schema instance owning this table
        self.name = name  # table name, str
        self.owner = owner  # owner role name
        self.size = size  # pg_relation_size() value from list_tables
        self.description = description  # obj_description() comment, if any

    def refresh(self):
        """Reload columns, constraints and indexes, in that order."""
        for reload in (self.refresh_columns,
                       self.refresh_constraints,
                       self.refresh_indexes):
            reload()

    def refresh_columns(self):
        """Re-query columns into an OrderedDict of name -> Column."""
        found = self.browser.list_columns(self.name, self.schema.name)
        self._columns = OrderedDict(
            (row['name'], Column(self.browser, self, **row)) for row in found)

    def refresh_constraints(self):
        """Re-query constraints into an OrderedDict of name -> Constraint."""
        found = self.browser.list_constraints(self.name, self.schema.name)
        self._constraints = OrderedDict(
            (row['name'], Constraint(self.browser, self, **row)) for row in found)

    def refresh_indexes(self):
        """Re-query indexes into an OrderedDict of name -> Index."""
        found = self.browser.list_indexes(self.name, self.schema.name)
        self._indexes = OrderedDict(
            (row['name'], Index(self.browser, self, **row)) for row in found)

    def _cached(self, attr, loader):
        # Run the loader only when the cache attribute is still None.
        if getattr(self, attr) is None:
            loader()
        return getattr(self, attr)

    def getcolumns(self):
        """Return the cached columns, loading them on first use."""
        return self._cached('_columns', self.refresh_columns)
    columns = property(getcolumns)

    def getconstraints(self):
        """Return the cached constraints, loading them on first use."""
        return self._cached('_constraints', self.refresh_constraints)
    constraints = property(getconstraints)

    def getindexes(self):
        """Return the cached indexes, loading them on first use."""
        return self._cached('_indexes', self.refresh_indexes)
    indexes = property(getindexes)


class Function:
    """A stored function, built from one row of ``PgBrowser.list_functions``."""

    def __init__(self, browser, schema, oid, name, type, arguments, result, source):
        self.browser = browser  # PgBrowser instance
        # Schema instance the function belongs to (Schema.refresh_functions
        # passes itself). Kept for parity with Table.schema.
        self.schema = schema
        self.oid = oid  # pg_proc.oid, used to fetch the full definition
        self.name = name  # pg_proc.proname
        self.type = type  # 'agg', 'window', 'trigger' or 'normal'
        self.arguments = arguments  # pg_get_function_arguments() text
        self.result = result  # pg_get_function_result() text
        self.source = source  # pg_proc.prosrc
        self._definition = None  # lazy cache for ``definition``

    @property
    def definition(self):
        """Get full function definition including CREATE command.

        Fetched once via ``browser.get_function_definition(oid)`` and cached.
        The value is whatever that method returns. NOTE(review):
        PgBrowser.get_function_definition currently returns the row list
        from its ``_query`` helper, not a bare string.
        """
        # Compare with None instead of testing truthiness, so an empty or
        # falsy result is cached too instead of re-querying on every access.
        if self._definition is None:
            self._definition = self.browser.get_function_definition(self.oid)
        return self._definition


class Schema:
    """A database schema (namespace) with lazily loaded tables and functions."""

    def __init__(self, browser, name, owner, acl, description, system):
        # Lazy caches behind the ``tables`` / ``functions`` properties.
        self._tables, self._functions = None, None
        self.browser = browser  # PgBrowser instance used for catalog queries
        self.name = name  # schema name
        self.owner = owner  # owner role name
        self.acl = acl  # pg_namespace.nspacl value
        self.description = description  # obj_description() comment, if any
        self.system = system  # "system" flag computed by list_schemas

    def refresh(self):
        """Reload tables first, then functions."""
        for reload in (self.refresh_tables, self.refresh_functions):
            reload()

    def refresh_tables(self):
        """Re-query tables into an OrderedDict of name -> Table."""
        found = self.browser.list_tables(self.name)
        self._tables = OrderedDict(
            (row['name'], Table(self.browser, self, **row)) for row in found)

    def refresh_functions(self):
        """Re-query functions into an OrderedDict of name -> Function.

        NOTE: entries are keyed by function name only. For an overloaded
        function, the later row with that name overwrites the earlier one,
        so only one overload is kept.
        """
        found = self.browser.list_functions(self.name)
        self._functions = OrderedDict(
            (row['name'], Function(self.browser, self, **row)) for row in found)

    @property
    def tables(self):
        """Tables of this schema, loaded on first access."""
        if self._tables is not None:
            return self._tables
        self.refresh_tables()
        return self._tables

    @property
    def functions(self):
        """Functions of this schema, loaded on first access."""
        if self._functions is not None:
            return self._functions
        self.refresh_functions()
        return self._functions


class PgBrowser:
    """Browse PostgreSQL schema metadata through a DB-API connection.

    The ``list_*`` methods run catalog queries and return lists of dicts, one
    dict per row, keyed by result column name. The ``schemas`` property builds
    a lazily loaded tree of Schema objects on top of them.
    """

    def __init__(self, conn=None):
        self._schemas = None  # cache behind the ``schemas`` property
        self.conn = conn  # DB-API connection; can be (re)set via setconn()

    def setconn(self, conn=None):
        """Set the connection used by subsequent queries."""
        self.conn = conn

    def refresh(self):
        """Reload cached metadata (currently just the schema list)."""
        self.refresh_schemas()

    def refresh_schemas(self):
        """Re-query schemas and rebuild the name -> Schema mapping."""
        rows = self.list_schemas()
        self._schemas = OrderedDict([(x['name'], Schema(self, **x)) for x in rows])

    @property
    def schemas(self):
        """OrderedDict of schema name -> Schema, loaded on first access."""
        if self._schemas is None:
            self.refresh_schemas()
        return self._schemas

    def _query(self, query, args):
        """Execute *query* with *args* and return the rows as dicts.

        Args:
            query: SQL text with DB-API placeholders.
            args: sequence or mapping of values for the placeholders.

        Returns:
            list of dicts, one per row, keyed by the column names from
            ``cursor.description``.

        The cursor's connection is committed right after execution, and the
        cursor is always closed.
        """
        # Create the cursor before entering try. If cursor() itself fails,
        # the finally clause must not touch an unbound ``curs``; doing so
        # would replace the real error with UnboundLocalError.
        curs = self.conn.cursor()
        try:
            curs.execute(query, args)
            curs.connection.commit()
            rows = curs.fetchall()
            # Column names are identical for every row; compute them once.
            names = [desc[0] for desc in curs.description]
            return [dict(zip(names, row)) for row in rows]
        finally:
            curs.close()

    def list_databases(self):
        """List databases.

        Returns dicts with keys name, owner, encoding, collation, ctype, acl,
        size (-1 when the current user lacks CONNECT privilege), tablespace
        and description.
        """
        return self._query('''
            SELECT
                d.datname as "name",
                pg_catalog.pg_get_userbyid(d.datdba) as "owner",
                pg_catalog.pg_encoding_to_char(d.encoding) as "encoding",
                d.datcollate as "collation",
                d.datctype as "ctype",
                d.datacl AS "acl",
                CASE WHEN pg_catalog.has_database_privilege(d.datname, 'CONNECT')
                    THEN pg_catalog.pg_database_size(d.datname)
                    ELSE -1 -- No access
                END as "size",
                t.spcname as "tablespace",
                pg_catalog.shobj_description(d.oid, 'pg_database') as "description"
            FROM pg_catalog.pg_database d
            JOIN pg_catalog.pg_tablespace t on d.dattablespace = t.oid
            ORDER BY 1;
            ''', [])

    def list_schemas(self):
        """List all schemas ordered by name.

        Returns dicts with keys name, owner, acl, description and system.
        ``system`` is true for information_schema, pg_catalog, pg_toast and
        pg_temp_* / pg_toast_temp_* schemas.
        """
        return self._query('''
            SELECT
                n.nspname AS "name",
                pg_catalog.pg_get_userbyid(n.nspowner) AS "owner",
                n.nspacl AS "acl",
                pg_catalog.obj_description(n.oid, 'pg_namespace') AS "description",
                CASE WHEN n.nspname IN ('information_schema', 'pg_catalog', 'pg_toast')
                    OR n.nspname ~ '^pg_temp_' OR n.nspname ~ '^pg_toast_temp_'
                    THEN TRUE
                    ELSE FALSE
                END AS "system"
            FROM pg_catalog.pg_namespace n
            ORDER BY 1;
            ''', [])

    def list_tables(self, schema='public'):
        """List tables in *schema* ordered by name.

        Only pg_class rows with relkind in ('r', 's', '') are returned.
        Returns dicts with keys name, owner, size (pg_relation_size) and
        description.
        """
        return self._query('''
            SELECT
                c.relname as "name",
                pg_catalog.pg_get_userbyid(c.relowner) as "owner",
                pg_catalog.pg_relation_size(c.oid) as "size",
                pg_catalog.obj_description(c.oid, 'pg_class') as "description"
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE n.nspname = %s AND c.relkind IN ('r','s','')
            ORDER BY 1;
            ''', [schema])

    def list_columns(self, table, schema='public', order=2):
        """List the non-dropped user columns of *schema*.*table*.

        Returns dicts with keys name, type, notnull, hasdefault, default and
        description.

        NOTE(review): *order* is spliced into the SQL with str(), not bound as
        a parameter, so it must come from trusted code only. With the current
        select list, column 1 is "name" and column 2 is "type", so the
        default order=2 sorts columns by type.
        """
        return self._query('''
            SELECT
                --a.attrelid,
                a.attname as "name",
                format_type(a.atttypid, a.atttypmod) AS "type",
                a.attnotnull as "notnull",
                a.atthasdef as "hasdefault",
                pg_catalog.pg_get_expr(d.adbin, d.adrelid) as "default",
                pg_catalog.col_description(a.attrelid, a.attnum) AS "description"
            FROM pg_catalog.pg_attribute a
            LEFT JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
            WHERE n.nspname = %s AND c.relname = %s AND a.attnum > 0 AND NOT a.attisdropped
            ORDER BY ''' + str(order), [schema, table])

    def list_constraints(self, table, schema='public'):
        """List constraints of *schema*.*table* ordered by name.

        Returns dicts with keys name, type (contype), fname and fschema (the
        referenced table and schema, NULL for non-foreign-key constraints)
        and definition.
        """
        return self._query('''
            SELECT
                r.conname AS "name",
                r.contype AS "type",
                cf.relname AS "fname",
                nf.nspname AS "fschema",
                pg_catalog.pg_get_constraintdef(r.oid, true) as "definition"
            FROM pg_catalog.pg_constraint r
            JOIN pg_catalog.pg_class c ON r.conrelid = c.oid
            JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            LEFT JOIN pg_catalog.pg_class cf ON r.confrelid = cf.oid
            LEFT JOIN pg_catalog.pg_namespace nf ON nf.oid = cf.relnamespace
            WHERE n.nspname = %s AND c.relname = %s
            ORDER BY 1
            ''', [schema, table])

    def list_indexes(self, table, schema='public'):
        """List indexes of *schema*.*table*.

        Primary key first, then unique indexes, then the rest by name.
        Returns dicts with keys name, primary, unique, clustered, valid,
        definition and columns (attribute names of the index relation).
        """
        return self._query('''
            SELECT
                c2.relname as "name",
                i.indisprimary as "primary",
                i.indisunique as "unique",
                i.indisclustered as "clustered",
                i.indisvalid as "valid",
                pg_catalog.pg_get_indexdef(i.indexrelid, 0, true) as "definition",
                ARRAY(SELECT a.attname FROM pg_catalog.pg_attribute a WHERE a.attrelid = c2.oid ORDER BY attnum) AS "columns"
                --c2.reltablespace as "tablespace_oid"
            FROM pg_catalog.pg_class c
            JOIN pg_catalog.pg_index i ON c.oid = i.indrelid
            JOIN pg_catalog.pg_class c2 ON i.indexrelid = c2.oid
            JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE n.nspname = %(schema)s AND c.relname = %(table)s
            ORDER BY i.indisprimary DESC, i.indisunique DESC, c2.relname
            ''', {'schema': schema, 'table': table})

    def list_functions(self, schema='public'):
        '''List functions in schema.

        Returns dicts with keys oid, name, result, arguments, source and
        type ('agg', 'window', 'trigger' or 'normal'), ordered by name and
        then argument list.

        NOTE(review): pg_proc.proisagg / proiswindow were replaced by
        prokind in PostgreSQL 11, so this query fails on 11 and later.
        '''
        # The schema is filtered explicitly, so there is no
        # pg_function_is_visible() check: that check would hide every
        # function whose schema is not on the current search_path.
        # Sort by name (2) and arguments (4). Column 1 here is the oid,
        # unlike psql's \df, where column 1 is the schema.
        return self._query('''
            SELECT
                p.oid as "oid",
                p.proname as "name",
                pg_catalog.pg_get_function_result(p.oid) as "result",
                pg_catalog.pg_get_function_arguments(p.oid) as "arguments",
                p.prosrc as "source",
                CASE
                    WHEN p.proisagg THEN 'agg'
                    WHEN p.proiswindow THEN 'window'
                    WHEN p.prorettype = 'pg_catalog.trigger'::pg_catalog.regtype THEN 'trigger'
                    ELSE 'normal'
                END as "type"
            FROM pg_catalog.pg_proc p
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
            WHERE n.nspname = %s
            ORDER BY 2, 4;
            ''', [schema])

    def get_function_definition(self, oid):
        """Get full function definition, including CREATE command etc.

        Args:
            oid: function oid from pg_catalog.pg_proc (returned by list_functions)

        Returns:
            The raw ``_query`` result: a one-element list holding a dict keyed
            by the result column name, not a bare string.

        """
        return self._query('''select pg_get_functiondef(%s);''', [oid])

    def list_sequences(self, schema=None):
        '''List sequences in schema.

        Only sequences that have a pg_depend link to a table column are
        returned, together with that table, column and column type. When
        *schema* is falsy, sequences from all schemas are listed, except
        other sessions' temporary schemas.
        '''
        schema_filter = ' AND nc.nspname = %(schema)s' if schema else ''
        return self._query('''
            SELECT
                nc.nspname AS "sequence_schema",
                c.relname AS "sequence_name",
                t.relname AS "related_table",
                a.attname AS "related_column",
                format_type(a.atttypid, a.atttypmod) AS "related_column_type"
            FROM pg_class c
            JOIN pg_namespace nc ON nc.oid = c.relnamespace
            JOIN pg_depend d ON d.objid = c.oid
            JOIN pg_class t ON d.refobjid = t.oid
            JOIN pg_attribute a ON (d.refobjid, d.refobjsubid) = (a.attrelid, a.attnum)
            WHERE c.relkind = 'S' AND NOT pg_is_other_temp_schema(nc.oid)
            ''' + schema_filter + '''
        ''', {'schema': schema})

    def list_column_usage(self, table, column, schema='public'):
        '''List objects using the column.

        Currently shows views and constraints which use the column.

        This is useful to find which views block alteration of column type etc.

        Returns dicts with keys type ('view' or 'constraint'), schema and name.

        '''
        return self._query('''
            SELECT
                'view' AS type, view_schema AS schema, view_name AS name
            FROM information_schema.view_column_usage
            WHERE table_schema=%(schema)s AND table_name=%(table)s AND column_name=%(column)s

            UNION

            SELECT
                'constraint' AS type, constraint_schema AS schema, constraint_name AS name
            FROM information_schema.constraint_column_usage
            WHERE table_schema=%(schema)s AND table_name=%(table)s AND column_name=%(column)s
            ''', {'schema':schema, 'table':table, 'column':column})

