pydbkit/mymanager_oursql.py
changeset 104 d8ff52a0390f
parent 77 2cfef775f518
child 106 db4c582a2abd
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pydbkit/mymanager_oursql.py	Wed Jul 09 18:03:54 2014 +0200
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+#
+# MyManager - manage database connections (MySQL version)
+#
+# Requires: Python 2.6 / 2.7 / 3.2, oursql
+#
+# Part of pydbkit
+# http://hg.devl.cz/pydbkit
+#
+# Copyright (c) 2011, 2013  Radek Brich <radek.brich@devl.cz>
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""MySQL database connection manager
+
+MyManager wraps oursql in same manner as PgManager wraps psycopg2.
+It's fully compatible so it should work as drop-in replacement for PgManager.
+
+It adds following features over oursql:
+
+ * Save and reuse database connection parameters
+
+ * Connection pooling
+
+ * Easy query using the with statement
+
+ * Dictionary rows
+
+Example:
+
+    from pydbkit import mymanager_oursql
+
+    dbm = mymanager.get_instance()
+    dbm.create_conn(host='127.0.0.1', dbname='default')
+
+    with dbm.cursor() as curs:
+        curs.execute('SELECT now() AS now')
+        row = curs.fetchone_dict()
+        print(row.now)
+
+See PgManager docs for more information.
+
+"""
+
+from contextlib import contextmanager
+from collections import OrderedDict
+import logging
+import threading
+import multiprocessing
+
+import oursql
+
+from oursql import DatabaseError, IntegrityError, OperationalError
+
+
+log_sql = logging.getLogger("mymanager_sql")
+log_sql.addHandler(logging.NullHandler())
+
+
+class MyManagerError(Exception):
+
+    pass
+
+
+class RowDict(OrderedDict):
+    """Special dictionary used for rows returned from queries.
+
+    Items keep order in which columns where returned from database.
+
+    It supports three styles of access:
+
+        Dict style:
+            row['id']
+            for key in row:
+                ...
+
+        Object style (only works if column name does not collide with any method name):
+            row.id
+
+        Tuple style:
+            row[0]
+            id, name = row.values()
+
+    """
+
+    def __getitem__(self, key):
+        if isinstance(key, int):
+            return tuple(self.values())[key]
+        else:
+            return OrderedDict.__getitem__(self, key)
+
+    def __getattr__(self, key):
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(key)
+
+
+class ConnectionInfo:
+
+    def __init__(self, name, isolation_level=None,
+                 init_statement=None, pool_size=1, **kw):
+        self.name = name  # connection name is logged with SQL queries
+        self.isolation_level = isolation_level
+        self.init_statement = init_statement
+        self.pool_size = pool_size
+        self.parameters = kw
+        self.adjust_parameters()
+
+    def adjust_parameters(self):
+        '''Rename Postgres parameters to proper value for MySQL.'''
+        m = {'dbname' : 'db', 'password' : 'passwd'}
+        res = dict()
+        for k, v in list(self.parameters.items()):
+            if k in m:
+                k = m[k]
+            res[k] = v
+        self.parameters = res
+
+
+class Cursor(oursql.Cursor):
+
+    def execute(self, query, args=[]):
+        try:
+            return super(Cursor, self).execute(query, args)
+        finally:
+            self._log_query(query, args)
+
+    def callproc(self, procname, args=[]):
+        try:
+            return super(Cursor, self).callproc(procname, args)
+        finally:
+            self._log_query(query, args)
+
+    def row_dict(self, row, lstrip=None):
+        adjustname = lambda a: a
+        if lstrip:
+            adjustname = lambda a: a.lstrip(lstrip)
+        return RowDict(zip([adjustname(desc[0]) for desc in self.description], row))
+
+    def fetchone_dict(self, lstrip=None):
+        row = super(Cursor, self).fetchone()
+        if row is None:
+            return None
+        return self.row_dict(row, lstrip)
+
+    def fetchall_dict(self, lstrip=None):
+        rows = super(Cursor, self).fetchall()
+        return [self.row_dict(row, lstrip) for row in rows]
+
+    def _log_query(self, query, args):
+        name = self.connection.name if hasattr(self.connection, 'name') else '-'
+        log_sql.debug('[%s] %s %s' % (name, query, args))
+
+
+class MyManager:
+
+    def __init__(self):
+        self.conn_known = {}  # available connections
+        self.conn_pool = {}
+        self.lock = threading.Lock()
+        self.pid = multiprocessing.current_process().pid  # forking check
+
+    def __del__(self):
+        for conn in tuple(self.conn_known.keys()):
+            self.destroy_conn(conn)
+
+    def create_conn(self, name='default', isolation_level=None, **kw):
+        '''Create named connection.'''
+        if name in self.conn_known:
+            raise MyManagerError('Connection name "%s" already registered.' % name)
+
+        isolation_level = self._normalize_isolation_level(isolation_level)
+        ci = ConnectionInfo(name, isolation_level, **kw)
+
+        self.conn_known[name] = ci
+        self.conn_pool[name] = []
+
+    def close_conn(self, name='default'):
+        '''Close all connections of given name.
+
+        Connection credentials are still saved.
+
+        '''
+        while len(self.conn_pool[name]):
+            conn = self.conn_pool[name].pop()
+            conn.close()
+
+    def destroy_conn(self, name='default'):
+        '''Destroy connection.
+
+        Counterpart of create_conn.
+
+        '''
+        if not name in self.conn_known:
+            raise MyManagerError('Connection name "%s" not registered.' % name)
+
+        self.close_conn(name)
+
+        del self.conn_known[name]
+        del self.conn_pool[name]
+
+    def get_conn(self, name='default'):
+        '''Get connection of name 'name' from pool.'''
+        self._check_fork()
+        self.lock.acquire()
+        try:
+            if not name in self.conn_known:
+                raise MyManagerError("Connection name '%s' not registered." % name)
+
+            # connection from pool
+            conn = None
+            while len(self.conn_pool[name]) and conn is None:
+                conn = self.conn_pool[name].pop()
+                try:
+                    conn.ping()
+                except oursql.MySQLError:
+                    conn.close()
+                    conn = None
+
+            if conn is None:
+                ci = self.conn_known[name]
+                conn = self._connect(ci)
+        finally:
+            self.lock.release()
+        return conn
+
+    def put_conn(self, conn, name='default'):
+        '''Put connection back to pool.
+
+        Name must be same as used for get_conn,
+        otherwise things become broken.
+
+        '''
+        self.lock.acquire()
+        try:
+            if not name in self.conn_known:
+                raise MyManagerError("Connection name '%s' not registered." % name)
+
+            if len(self.conn_pool[name]) >= self.conn_known[name].pool_size:
+                conn.close()
+                return
+
+            # connection returned to the pool must not be in transaction
+            try:
+                conn.rollback()
+            except OperationalError:
+                conn.close()
+                return
+
+            self.conn_pool[name].append(conn)
+        finally:
+            self.lock.release()
+
+    @contextmanager
+    def cursor(self, name='default'):
+        '''Cursor context.
+
+        Uses any connection of name 'name' from pool
+        and returns cursor for that connection.
+
+        '''
+        conn = self.get_conn(name)
+
+        try:
+            curs = conn.cursor()
+            yield curs
+        finally:
+            curs.close()
+            self.put_conn(conn, name)
+
+    def _connect(self, ci):
+        conn = oursql.connect(default_cursor=Cursor, **ci.parameters)
+        if not ci.isolation_level is None:
+            if ci.isolation_level == 'AUTOCOMMIT':
+                conn.autocommit(True)
+            else:
+                curs = conn.cursor()
+                curs.execute('SET SESSION TRANSACTION ISOLATION LEVEL ' + ci.isolation_level)
+                curs.close()
+        if ci.init_statement:
+            curs = conn.cursor()
+            curs.execute(ci.init_statement)
+            curs.connection.commit()
+            curs.close()
+        return conn
+
+    def _normalize_isolation_level(self, level):
+        if level is None:
+            return level
+        if type(level) == str:
+            level = level.upper().replace('_', ' ')
+            if level in (
+                'AUTOCOMMIT',
+                'READ UNCOMMITTED',
+                'READ COMMITTED',
+                'REPEATABLE READ',
+                'SERIALIZABLE'):
+                return level
+        raise MyManagerError('Unknown isolation level name: "%s"', level)
+
+    def _check_fork(self):
+        '''Check if process was forked (PID has changed).
+
+        If it was, clean parent's connections.
+        New connections are created for children.
+        Known connection credentials are inherited, but not shared.
+
+        '''
+        if self.pid == multiprocessing.current_process().pid:
+            # PID has not changed
+            return
+
+        # update saved PID
+        self.pid = multiprocessing.current_process().pid
+        # reinitialize lock
+        self.lock = threading.Lock()
+        # clean parent's connections
+        for name in self.conn_pool:
+            self.conn_pool[name] = []
+
+    @classmethod
+    def get_instance(cls):
+        if not hasattr(cls, '_instance'):
+            cls._instance = cls()
+        return cls._instance
+
+
+def get_instance():
+    return MyManager.get_instance()
+