mytoolkit/mymanager.py
changeset 71 4251068a251a
parent 49 08e4dfe1b0cb
child 74 d4306261ddfb
--- a/mytoolkit/mymanager.py	Mon Mar 04 15:39:34 2013 +0100
+++ b/mytoolkit/mymanager.py	Tue Mar 05 11:24:47 2013 +0100
@@ -2,12 +2,12 @@
 #
 # MyManager - manage database connections (MySQL version)
 #
-# Requires: Python 2.6 / 2.7, MySQLdb
+# Requires: Python 2.6 / 2.7 / 3.2, MySQLdb
 #
 # Part of pgtoolkit
 # http://hg.devl.cz/pgtoolkit
 #
-# Copyright (c) 2011  Radek Brich <radek.brich@devl.cz>
+# Copyright (c) 2011, 2013  Radek Brich <radek.brich@devl.cz>
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -29,48 +29,37 @@
 
 """MySQL database connection manager
 
-MyManager wraps MySQLdb connect function, adding following features:
+MyManager wraps MySQLdb in same manner as PgManager wraps psycopg2.
+It's fully compatible so it should work as drop-in replacement for PgManager.
 
- * Manage database connection parameters - link connection parameters
-   to an unique identifier, retrieve connection object by this identifier
+It adds following features over MySQLdb:
 
- * Connection pooling - connections with same identifier are pooled and reused
+ * Save and reuse database connection parameters
+
+ * Connection pooling
 
- * Easy query using the with statement - retrieve cursor directly by connection
-   identifier, don't worry about connections
+ * Easy query using the with statement
 
- * Dict rows - cursor has additional methods like fetchall_dict(), which
-   returns dict row instead of ordinary list-like row
+ * Dictionary rows
 
 Example:
 
-import mymanager
-
-db = mymanager.get_instance()
-db.create_conn(host='127.0.0.1', db='default')
+    from mytoolkit import mymanager
 
-with db.cursor() as curs:
-    curs.execute('SELECT now() AS now')
-    row = curs.fetchone_dict()
-    print row.now
+    dbm = mymanager.get_instance()
+    dbm.create_conn(host='127.0.0.1', dbname='default')
 
-First, we have obtained MyManager instance. This is like calling
-MyManager(), although in our example the instance is global. That means
-getting the instance in another module brings us all the defined connections
-etc.
+    with dbm.cursor() as curs:
+        curs.execute('SELECT now() AS now')
+        row = curs.fetchone_dict()
+        print(row.now)
 
-On next line we have created connection named 'default' (this name can be left out).
-The with statement obtains connection (actually connects to database when needed),
-then returns cursor for this connection. At the end of with statement,
-the connection is returned to the pool or closed (depending on number of connections
-in pool and on setting of keep_open parameter).
-
-The row returned by fetchone_dict() is special dict object, which can be accessed
-using item or attribute access, that is row['now'] or row.now.
+See PgManager docs for more information.
 
 """
 
 from contextlib import contextmanager
+from collections import OrderedDict
 import logging
 import threading
 
@@ -79,6 +68,11 @@
 
 from MySQLdb import DatabaseError, IntegrityError, OperationalError
 
+from pgtoolkit.pgmanager import RowDict
+
+
+log_sql = logging.getLogger("mymanager_sql")
+
 
 class MyManagerError(Exception):
 
@@ -87,10 +81,12 @@
 
 class ConnectionInfo:
 
-    def __init__(self, isolation_level=None, init_statement=None, keep_open=1, **kw):
+    def __init__(self, name, isolation_level=None,
+                 init_statement=None, pool_size=1, **kw):
+        self.name = name  # connection name is logged with SQL queries
         self.isolation_level = isolation_level
         self.init_statement = init_statement
-        self.keep_open = keep_open
+        self.pool_size = pool_size
         self.parameters = kw
         self.adjust_parameters()
 
@@ -105,25 +101,19 @@
         self.parameters = res
 
 
-class RowDict(dict):
-
-    def __getattr__(self, key):
-        return self[key]
-
-
 class Cursor(MySQLdb.cursors.Cursor):
 
     def execute(self, query, args=None):
         try:
             return super(Cursor, self).execute(query, args)
         finally:
-            log.debug(self._executed.decode('utf8'))
+            log_sql.debug(self._executed.decode('utf8'))
 
     def callproc(self, procname, args=None):
         try:
             return super(Cursor, self).callproc(procname, args)
         finally:
-            log.debug(self._executed.decode('utf8'))
+            log_sql.debug(self._executed.decode('utf8'))
 
     def row_dict(self, row, lstrip=None):
         adjustname = lambda a: a
@@ -148,6 +138,7 @@
         self.conn_known = {}  # available connections
         self.conn_pool = {}
         self.lock = threading.Lock()
+        self.pid = multiprocessing.current_process().pid  # forking check
 
     def __del__(self):
         for conn in tuple(self.conn_known.keys()):
@@ -159,7 +150,7 @@
             raise MyManagerError('Connection name "%s" already registered.' % name)
 
         isolation_level = self._normalize_isolation_level(isolation_level)
-        ci = ConnectionInfo(isolation_level, **kw)
+        ci = ConnectionInfo(name, isolation_level, **kw)
 
         self.conn_known[name] = ci
         self.conn_pool[name] = []
@@ -190,11 +181,13 @@
 
     def get_conn(self, name='default'):
         '''Get connection of name 'name' from pool.'''
+        self._check_fork()
         self.lock.acquire()
         try:
             if not name in self.conn_known:
                 raise MyManagerError("Connection name '%s' not registered." % name)
 
+            # connection from pool
             conn = None
             while len(self.conn_pool[name]) and conn is None:
                 conn = self.conn_pool[name].pop()
@@ -223,7 +216,7 @@
             if not name in self.conn_known:
                 raise MyManagerError("Connection name '%s' not registered." % name)
 
-            if len(self.conn_pool[name]) >= self.conn_known[name].keep_open:
+            if len(self.conn_pool[name]) >= self.conn_known[name].pool_size:
                 conn.close()
                 return
 
@@ -267,6 +260,7 @@
         if ci.init_statement:
             curs = conn.cursor()
             curs.execute(ci.init_statement)
+            curs.connection.commit()
             curs.close()
         return conn
 
@@ -284,26 +278,33 @@
                 return level
         raise MyManagerError('Unknown isolation level name: "%s"', level)
 
+    def _check_fork(self):
+        '''Check if process was forked (PID has changed).
 
-try:
-    NullHandler = logging.NullHandler
-except AttributeError:
-    class NullHandler(logging.Handler):
-        def emit(self, record):
-            pass
+        If it was, clean parent's connections.
+        New connections are created for children.
+        Known connection credentials are inherited, but not shared.
+
+        '''
+        if self.pid == multiprocessing.current_process().pid:
+            # PID has not changed
+            return
 
+        # update saved PID
+        self.pid = multiprocessing.current_process().pid
+        # reinitialize lock
+        self.lock = threading.Lock()
+        # clean parent's connections
+        for name in self.conn_pool:
+            self.conn_pool[name] = []
 
-log = logging.getLogger("mymanager")
-log.addHandler(NullHandler())
-
-
-instance = None
+    @classmethod
+    def get_instance(cls):
+        if not hasattr(cls, '_instance'):
+            cls._instance = cls()
+        return cls._instance
 
 
 def get_instance():
-    global instance
-    if instance is None:
-        instance = MyManager()
-    return instance
+    return MyManager.get_instance()
 
-