pgtoolkit/pgmanager.py
changeset 104 d8ff52a0390f
parent 103 24e94a3da209
child 105 10551741f61f
equal deleted inserted replaced
103:24e94a3da209 104:d8ff52a0390f
     1 # -*- coding: utf-8 -*-
       
     2 #
       
     3 # PgManager - manage database connections
       
     4 #
       
     5 # Requires: Python 3.2, psycopg2
       
     6 #
       
     7 # Part of pgtoolkit
       
     8 # http://hg.devl.cz/pgtoolkit
       
     9 #
       
    10 # Copyright (c) 2010, 2011, 2012, 2013  Radek Brich <radek.brich@devl.cz>
       
    11 #
       
    12 # Permission is hereby granted, free of charge, to any person obtaining a copy
       
    13 # of this software and associated documentation files (the "Software"), to deal
       
    14 # in the Software without restriction, including without limitation the rights
       
    15 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       
    16 # copies of the Software, and to permit persons to whom the Software is
       
    17 # furnished to do so, subject to the following conditions:
       
    18 #
       
    19 # The above copyright notice and this permission notice shall be included in
       
    20 # all copies or substantial portions of the Software.
       
    21 #
       
    22 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
       
    23 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
       
    24 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
       
    25 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
       
    26 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
       
    27 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
       
    28 # THE SOFTWARE.
       
    29 
       
    30 """Postgres database connection manager
       
    31 
       
    32 PgManager wraps psycopg2, adding the following features:
       
    33 
       
    34  * Save and reuse database connection parameters
       
    35 
       
    36  * Connection pooling
       
    37 
       
    38  * Easy query using the with statement
       
    39 
       
    40  * Dictionary rows
       
    41 
       
    42 Example usage:
       
    43 
       
    44     from pgtoolkit import pgmanager
       
    45 
       
    46     pgm = pgmanager.get_instance()
       
    47     pgm.create_conn(hostaddr='127.0.0.1', dbname='postgres')
       
    48 
       
    49     with pgm.cursor() as curs:
       
    50         curs.execute('SELECT now() AS now')
       
    51         row = curs.fetchone_dict()
       
    52         print(row.now)
       
    53 
       
    54 First, we obtain a PgManager instance. This is like calling
       
    55 PgManager(), although in our example the instance is global. That means
       
    56 getting the instance in another module brings us all the defined connections
       
    57 etc.
       
    58 
       
    59 On the second line we create a connection named 'default' (this name can be left out).
       
    60 The with statement obtains connection (actually connects to database when needed),
       
    61 then returns cursor for this connection. At the end of with statement,
       
    62 the connection is returned to the pool or closed (depending on number of connections
       
    63 in pool and on setting of pool_size parameter).
       
    64 
       
    65 The row returned by fetchone_dict() is special dict object, which can be accessed
       
    66 using item or attribute access, that is row['now'] or row.now.
       
    67 
       
    68 """
       
    69 
       
    70 from contextlib import contextmanager
       
    71 from collections import OrderedDict
       
    72 import logging
       
    73 import threading
       
    74 import multiprocessing
       
    75 import select
       
    76 import socket
       
    77 
       
    78 import psycopg2
       
    79 import psycopg2.extensions
       
    80 
       
    81 from psycopg2 import DatabaseError, IntegrityError, OperationalError
       
    82 
       
    83 
       
# Logger used by Cursor for executed queries (debug) and failed queries (exception).
log_sql = logging.getLogger("pgmanager_sql")
# Logger used by PgManager.log_notices() for notices collected on a connection (info).
log_notices = logging.getLogger("pgmanager_notices")
# Attach a do-nothing handler so the SQL logger always has at least one handler,
# even when the application has not configured logging.
log_sql.addHandler(logging.NullHandler())
# NullHandler not needed for notices which are INFO level only
       
    88 
       
    89 
       
    90 class PgManagerError(Exception):
       
    91 
       
    92     pass
       
    93 
       
    94 
       
    95 class ConnectionInfo:
       
    96 
       
    97     def __init__(self, name, dsn, isolation_level=None, keep_alive=True,
       
    98                  init_statement=None, pool_size=1):
       
    99         self.name = name  # connection name is logged with SQL queries
       
   100         self.dsn = dsn  # dsn or string with connection parameters
       
   101         self.isolation_level = isolation_level
       
   102         self.keep_alive = keep_alive
       
   103         self.init_statement = init_statement
       
   104         self.pool_size = pool_size
       
   105 
       
   106 
       
   107 class RowDict(OrderedDict):
       
   108     """Special dictionary used for rows returned from queries.
       
   109 
       
   110     Items keep order in which columns where returned from database.
       
   111 
       
   112     It supports three styles of access:
       
   113 
       
   114         Dict style:
       
   115             row['id']
       
   116             for key in row:
       
   117                 ...
       
   118 
       
   119         Object style (only works if column name does not collide with any method name):
       
   120             row.id
       
   121 
       
   122         Tuple style:
       
   123             row[0]
       
   124             id, name = row.values()
       
   125 
       
   126     """
       
   127 
       
   128     def __getitem__(self, key):
       
   129         if isinstance(key, int):
       
   130             return tuple(self.values())[key]
       
   131         else:
       
   132             return OrderedDict.__getitem__(self, key)
       
   133 
       
   134     def __getattr__(self, key):
       
   135         try:
       
   136             return self[key]
       
   137         except KeyError:
       
   138             raise AttributeError(key)
       
   139 
       
   140 
       
   141 class Cursor(psycopg2.extensions.cursor):
       
   142 
       
   143     def execute(self, query, args=None):
       
   144         # log query before executing
       
   145         self._log_query(query, args)
       
   146         try:
       
   147             return super(Cursor, self).execute(query, args)
       
   148         except DatabaseError:
       
   149             self._log_exception()
       
   150             raise
       
   151 
       
   152     def callproc(self, procname, args=None):
       
   153         # log query before executing (not query actually executed but should correspond)
       
   154         self._log_query(self._build_callproc_query(procname, len(args)), args)
       
   155         try:
       
   156             return super(Cursor, self).callproc(procname, args)
       
   157         except DatabaseError:
       
   158             self._log_exception()
       
   159             raise
       
   160 
       
   161     def row_dict(self, row, lstrip=None):
       
   162         adjustname = lambda a: a
       
   163         if lstrip:
       
   164             adjustname = lambda a: a.lstrip(lstrip)
       
   165         return RowDict(zip([adjustname(desc[0]) for desc in self.description], row))
       
   166 
       
   167     def fetchone_dict(self, lstrip=None):
       
   168         '''Return one row as OrderedDict'''
       
   169         row = super(Cursor, self).fetchone()
       
   170         if row is None:
       
   171             return None
       
   172         return self.row_dict(row, lstrip)
       
   173 
       
   174     def fetchall_dict(self, lstrip=None):
       
   175         '''Return all rows as OrderedDict'''
       
   176         rows = super(Cursor, self).fetchall()
       
   177         return [self.row_dict(row, lstrip) for row in rows]
       
   178 
       
   179     def adapt(self, row):
       
   180         if isinstance(row, RowDict):
       
   181             # dict
       
   182             adapted = dict()
       
   183             for key in row.keys():
       
   184                 adapted[key] = self.mogrify('%s', [row[key]]).decode('utf8')
       
   185             return RowDict(adapted)
       
   186         else:
       
   187             # list
       
   188             return [self.mogrify('%s', [x]).decode('utf8') for x in row]
       
   189 
       
   190     def fetchone_adapted(self, lstrip=None):
       
   191         '''Like fetchone_dict() but values are quoted for direct inclusion in SQL query.
       
   192 
       
   193         This is useful when you need to generate SQL script from data returned
       
   194         by the query. Use mogrify() for simple cases.
       
   195 
       
   196         '''
       
   197         row = super(Cursor, self).fetchone()
       
   198         if row is None:
       
   199             return None
       
   200         return self.row_dict([self.mogrify('%s', [x]).decode('utf8') for x in row], lstrip)
       
   201 
       
   202     def fetchall_adapted(self, lstrip=None):
       
   203         '''Like fetchall_dict() but values are quoted for direct inclusion in SQL query.'''
       
   204         rows = super(Cursor, self).fetchall()
       
   205         return [self.row_dict([self.mogrify('%s', [x]).decode('utf8') for x in row], lstrip) for row in rows]
       
   206 
       
   207     def _log_query(self, query='?', args=None):
       
   208         name = self.connection.name if hasattr(self.connection, 'name') else '-'
       
   209         query = self.mogrify(query, args)
       
   210         log_sql.debug('[%s] %s' % (name, query.decode('utf8')))
       
   211 
       
   212     def _log_exception(self):
       
   213         name = self.connection.name if hasattr(self.connection, 'name') else '-'
       
   214         log_sql.exception('[%s] exception:' % (name,))
       
   215 
       
   216     def _build_callproc_query(self, procname, num_args):
       
   217         return 'SELECT * FROM %s(%s)' % (procname, ', '.join(['%s'] * num_args))
       
   218 
       
   219 
       
   220 class Connection(psycopg2.extensions.connection):
       
   221 
       
   222     def cursor(self, name=None):
       
   223         if name is None:
       
   224             return super(Connection, self).cursor(cursor_factory=Cursor)
       
   225         else:
       
   226             return super(Connection, self).cursor(name, cursor_factory=Cursor)
       
   227 
       
   228     def keep_alive(self):
       
   229         '''Set socket to keepalive mode. Must be called before any query.'''
       
   230         sock = socket.fromfd(self.fileno(), socket.AF_INET, socket.SOCK_STREAM)
       
   231         sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
       
   232         try:
       
   233             # Maximum keep-alive probes before asuming the connection is lost
       
   234             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
       
   235             # Interval (in seconds) between keep-alive probes
       
   236             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 2)
       
   237             # Maximum idle time (in seconds) before start sending keep-alive probes
       
   238             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 10)
       
   239         except socket.error:
       
   240             pass
       
   241         # close duplicated fd, options set for socket stays
       
   242         sock.close()
       
   243 
       
   244 
       
   245 class PgManager:
       
   246 
       
   247     def __init__(self):
       
   248         self.conn_known = {}  # available connections
       
   249         self.conn_pool = {}  # active connetions
       
   250         self.lock = threading.Lock()  # mutual exclusion for threads
       
   251         self.pid = multiprocessing.current_process().pid  # forking check
       
   252 
       
   253     def __del__(self):
       
   254         for conn in tuple(self.conn_known.keys()):
       
   255             self.destroy_conn(conn)
       
   256 
       
   257     def create_conn(self, name='default', isolation_level=None, keep_alive=True, init_statement=None,
       
   258                     pool_size=1, dsn=None, **kwargs):
       
   259         '''Create named connection.
       
   260 
       
   261         *name* -- name for connection
       
   262 
       
   263         *pool_size* -- how many connections will be kept open in pool.
       
   264         More connections will still be created but they will be closed by put_conn.
       
   265         `None` will disable pool, get_conn() will then always return same connection.
       
   266 
       
   267         *isolation_level* -- `"autocommit"`, `"read_committed"`, `"serializable"` or `None` for driver default
       
   268 
       
   269         *keep_alive* -- set socket to keepalive mode
       
   270 
       
   271         *dsn* -- connection string (parameters or data source name)
       
   272 
       
   273         Other keyword args are used as connection parameters.
       
   274 
       
   275         '''
       
   276         if name in self.conn_known:
       
   277             raise PgManagerError('Connection name "%s" already registered.' % name)
       
   278 
       
   279         if dsn is None:
       
   280             dsn = ' '.join([x[0]+'='+str(x[1]) for x in kwargs.items() if x[1] is not None])
       
   281 
       
   282         isolation_level = self._normalize_isolation_level(isolation_level)
       
   283         ci = ConnectionInfo(name, dsn, isolation_level, keep_alive, init_statement, pool_size)
       
   284 
       
   285         self.conn_known[name] = ci
       
   286         self.conn_pool[name] = []
       
   287 
       
   288     def create_conn_listen(self, name, channel, dsn=None, copy_dsn=None, **kwargs):
       
   289         '''Create connection listening for notifies.
       
   290 
       
   291         Disables pool. If you want to use pool, create other connection for that.
       
   292         This connection can be used as usual: conn.cursor() etc.
       
   293         Don't use PgManager's cursor() and put_conn().
       
   294 
       
   295         *name* -- name for connection
       
   296 
       
   297         *channel* -- listen on this channel
       
   298 
       
   299         *copy_dsn* -- specify name of other connection and its dsn will be used
       
   300 
       
   301         Other parameters forwarded to create_conn().
       
   302 
       
   303         '''
       
   304         if dsn is None and copy_dsn:
       
   305             try:
       
   306                 dsn = self.conn_known[copy_dsn].dsn
       
   307             except KeyError:
       
   308                 raise PgManagerError("Connection name '%s' not registered." % copy_dsn)
       
   309         listen_query = "LISTEN " + channel
       
   310         self.create_conn(name=name, pool_size=None, isolation_level='autocommit', init_statement=listen_query,
       
   311             dsn=dsn, **kwargs)
       
   312 
       
   313     def close_conn(self, name='default'):
       
   314         '''Close all connections of given name.
       
   315 
       
   316         Connection credentials are still saved.
       
   317 
       
   318         '''
       
   319         while len(self.conn_pool[name]):
       
   320             conn = self.conn_pool[name].pop()
       
   321             conn.close()
       
   322 
       
   323     def destroy_conn(self, name='default'):
       
   324         '''Destroy connection.
       
   325 
       
   326         Counterpart of create_conn.
       
   327 
       
   328         '''
       
   329         if not name in self.conn_known:
       
   330             raise PgManagerError('Connection name "%s" not registered.' % name)
       
   331 
       
   332         self.close_conn(name)
       
   333 
       
   334         del self.conn_known[name]
       
   335         del self.conn_pool[name]
       
   336 
       
   337     def knows_conn(self, name='default'):
       
   338         return name in self.conn_known
       
   339 
       
   340     def get_conn(self, name='default'):
       
   341         '''Get connection of name 'name' from pool.'''
       
   342         self._check_fork()
       
   343         self.lock.acquire()
       
   344         try:
       
   345             try:
       
   346                 ci = self.conn_known[name]
       
   347             except KeyError:
       
   348                 raise PgManagerError("Connection name '%s' not registered." % name)
       
   349 
       
   350             # no pool, just one static connection
       
   351             if ci.pool_size is None:
       
   352                 # check for existing connection
       
   353                 try:
       
   354                     conn = self.conn_pool[name][0]
       
   355                     if conn.closed:
       
   356                         conn = None
       
   357                 except IndexError:
       
   358                     conn = None
       
   359                     self.conn_pool[name].append(conn)
       
   360                 # if no existing connection is valid, connect new one and save it
       
   361                 if conn is None:
       
   362                     conn = self._connect(ci)
       
   363                     self.conn_pool[name][0] = conn
       
   364 
       
   365             # connection from pool
       
   366             else:
       
   367                 conn = None
       
   368                 while len(self.conn_pool[name]) and conn is None:
       
   369                     conn = self.conn_pool[name].pop()
       
   370                     if conn.closed:
       
   371                         conn = None
       
   372 
       
   373                 if conn is None:
       
   374                     conn = self._connect(ci)
       
   375         finally:
       
   376             self.lock.release()
       
   377         return conn
       
   378 
       
   379     def put_conn(self, conn, name='default'):
       
   380         '''Put connection back to pool.
       
   381 
       
   382         *name* must be same as used for get_conn, otherwise things become broken.
       
   383 
       
   384         '''
       
   385         self.lock.acquire()
       
   386         try:
       
   387             if not name in self.conn_known:
       
   388                 raise PgManagerError("Connection name '%s' not registered." % name)
       
   389 
       
   390             if len(self.conn_pool[name]) >= self.conn_known[name].pool_size:
       
   391                 conn.close()
       
   392                 return
       
   393 
       
   394             if conn.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_UNKNOWN:
       
   395                 conn.close()
       
   396                 return
       
   397 
       
   398             # connection returned to the pool must not be in transaction
       
   399             if conn.get_transaction_status() != psycopg2.extensions.TRANSACTION_STATUS_IDLE:
       
   400                 try:
       
   401                     conn.rollback()
       
   402                 except OperationalError:
       
   403                     if not conn.closed:
       
   404                         conn.close()
       
   405                     return
       
   406 
       
   407             self.conn_pool[name].append(conn)
       
   408         finally:
       
   409             self.lock.release()
       
   410 
       
   411     @contextmanager
       
   412     def cursor(self, name='default'):
       
   413         '''Cursor context.
       
   414 
       
   415         Uses any connection info with *name* from pool
       
   416         and returns cursor for that connection.
       
   417 
       
   418         '''
       
   419         conn = self.get_conn(name)
       
   420 
       
   421         try:
       
   422             curs = conn.cursor()
       
   423             yield curs
       
   424         finally:
       
   425             curs.close()
       
   426             self.log_notices(conn)
       
   427             self.put_conn(conn, name)
       
   428 
       
   429     def log_notices(self, conn):
       
   430         for notice in conn.notices:
       
   431             log_notices.info(notice.rstrip())
       
   432         conn.notices[:] = []
       
   433 
       
   434     def wait_for_notify(self, name='default', timeout=None):
       
   435         '''Wait for asynchronous notifies, return the last one.
       
   436 
       
   437         *name* -- name of connection, must be created using `create_conn_listen()`
       
   438 
       
   439         *timeout* -- in seconds, floating point (`None` means wait forever)
       
   440 
       
   441         Returns `None` on timeout.
       
   442 
       
   443         '''
       
   444         conn = self.get_conn(name)
       
   445 
       
   446         # return any notifies on stack
       
   447         if conn.notifies:
       
   448             return conn.notifies.pop()
       
   449 
       
   450         if select.select([conn], [], [], timeout) == ([], [], []):
       
   451             # timeout
       
   452             return None
       
   453         else:
       
   454             conn.poll()
       
   455 
       
   456             # return just the last notify (we do not care for older ones)
       
   457             if conn.notifies:
       
   458                 return conn.notifies.pop()
       
   459             return None
       
   460 
       
   461     def _connect(self, ci):
       
   462         conn = psycopg2.connect(ci.dsn, connection_factory=Connection)
       
   463         conn.name = ci.name
       
   464         if ci.keep_alive:
       
   465             conn.keep_alive()
       
   466         if not ci.isolation_level is None:
       
   467             conn.set_isolation_level(ci.isolation_level)
       
   468         if ci.init_statement:
       
   469             curs = conn.cursor()
       
   470             curs.execute(ci.init_statement)
       
   471             curs.connection.commit()
       
   472             curs.close()
       
   473         return conn
       
   474 
       
   475     def _normalize_isolation_level(self, level):
       
   476         if type(level) == str:
       
   477             if level.lower() == 'autocommit':
       
   478                 return psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
       
   479             if level.lower() == 'read_committed':
       
   480                 return psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED
       
   481             if level.lower() == 'serializable':
       
   482                 return psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE
       
   483             raise PgManagerError('Unknown isolation level name: "%s"' % level)
       
   484         return level
       
   485 
       
   486     def _check_fork(self):
       
   487         '''Check if process was forked (PID has changed).
       
   488 
       
   489         If it was, clean parent's connections.
       
   490         New connections are created for children.
       
   491         Known connection credentials are inherited, but not shared.
       
   492 
       
   493         '''
       
   494         if self.pid == multiprocessing.current_process().pid:
       
   495             # PID has not changed
       
   496             return
       
   497 
       
   498         # update saved PID
       
   499         self.pid = multiprocessing.current_process().pid
       
   500         # reinitialize lock
       
   501         self.lock = threading.Lock()
       
   502         # clean parent's connections
       
   503         for name in self.conn_pool:
       
   504             self.conn_pool[name] = []
       
   505 
       
   506     @classmethod
       
   507     def get_instance(cls):
       
   508         if not hasattr(cls, '_instance'):
       
   509             cls._instance = cls()
       
   510         return cls._instance
       
   511 
       
   512 
       
   513 def get_instance():
       
   514     return PgManager.get_instance()
       
   515