0001"""
0002Database API
0003(part of web.py)
0004"""
0005
0006# todo:
0007#  - test with sqllite
0008#  - a store function?
0009
0010__all__ = [
0011  "UnknownParamstyle", "UnknownDB",
0012  "sqllist", "sqlors", "aparam", "reparam",
0013  "SQLQuery", "sqlquote",
0014  "connect",
0015  "transact", "commit", "rollback",
0016  "query",
0017  "select", "insert", "update", "delete"
0018]
0019
0020import time
0021try: import datetime
0022except ImportError: datetime = None
0023
0024from utils import storage, iters, iterbetter
0025import webapi as web
0026
0027try:
0028    from DBUtils.PooledDB import PooledDB
0029    web.config._hasPooling = True
0030except ImportError:
0031    web.config._hasPooling = False
0032
0033class _ItplError(ValueError):
0034    def __init__(self, text, pos):
0035        ValueError.__init__(self)
0036        self.text = text
0037        self.pos = pos
0038    def __str__(self):
0039        return "unfinished expression in %s at char %d" % (
0040            repr(self.text), self.pos)
0041
0042def _interpolate(format):
0043    """
0044    Takes a format string and returns a list of 2-tuples of the form
0045    (boolean, string) where boolean says whether string should be evaled
0046    or not.
0047    
0048    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
0049    """
0050    from tokenize import tokenprog
0051
0052    def matchorfail(text, pos):
0053        match = tokenprog.match(text, pos)
0054        if match is None:
0055            raise _ItplError(text, pos)
0056        return match, match.end()
0057
0058    namechars = "abcdefghijklmnopqrstuvwxyz"           "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
0060    chunks = []
0061    pos = 0
0062
0063    while 1:
0064        dollar = format.find("$", pos)
0065        if dollar < 0:
0066            break
0067        nextchar = format[dollar + 1]
0068
0069        if nextchar == "{":
0070            chunks.append((0, format[pos:dollar]))
0071            pos, level = dollar + 2, 1
0072            while level:
0073                match, pos = matchorfail(format, pos)
0074                tstart, tend = match.regs[3]
0075                token = format[tstart:tend]
0076                if token == "{":
0077                    level = level + 1
0078                elif token == "}":
0079                    level = level - 1
0080            chunks.append((1, format[dollar + 2:pos - 1]))
0081
0082        elif nextchar in namechars:
0083            chunks.append((0, format[pos:dollar]))
0084            match, pos = matchorfail(format, dollar + 1)
0085            while pos < len(format):
0086                if format[pos] == "." and                       pos + 1 < len(format) and format[pos + 1] in namechars:
0088                    match, pos = matchorfail(format, pos + 1)
0089                elif format[pos] in "([":
0090                    pos, level = pos + 1, 1
0091                    while level:
0092                        match, pos = matchorfail(format, pos)
0093                        tstart, tend = match.regs[3]
0094                        token = format[tstart:tend]
0095                        if token[0] in "([":
0096                            level = level + 1
0097                        elif token[0] in ")]":
0098                            level = level - 1
0099                else:
0100                    break
0101            chunks.append((1, format[dollar + 1:pos]))
0102
0103        else:
0104            chunks.append((0, format[pos:dollar + 1]))
0105            pos = dollar + 1 + (nextchar == "$")
0106
0107    if pos < len(format):
0108        chunks.append((0, format[pos:]))
0109    return chunks
0110
0111class UnknownParamstyle(Exception):
0112    """
0113    raised for unsupported db paramstyles
0114    
0115    (currently supported: qmark, numeric, format, pyformat)
0116    """
0117    pass
0118
0119def aparam():
0120    """
0121    Returns the appropriate string to be used to interpolate
0122    a value with the current `web.ctx.db_module` or simply %s
0123    if there isn't one.
0124    
0125        >>> aparam()
0126        '%s'
0127    """
0128    if hasattr(web.ctx, 'db_module'):
0129        style = web.ctx.db_module.paramstyle
0130    else:
0131        style = 'pyformat'
0132
0133    if style == 'qmark':
0134        return '?'
0135    elif style == 'numeric':
0136        return ':1'
0137    elif style in ['format', 'pyformat']:
0138        return '%s'
0139    raise UnknownParamstyle, style
0140
0141def reparam(string_, dictionary):
0142    """
0143    Takes a string and a dictionary and interpolates the string
0144    using values from the dictionary. Returns an `SQLQuery` for the result.
0145    
0146        >>> reparam("s = $s", dict(s=True))
0147        <sql: "s = 't'">
0148    """
0149    vals = []
0150    result = []
0151    for live, chunk in _interpolate(string_):
0152        if live:
0153            result.append(aparam())
0154            vals.append(eval(chunk, dictionary))
0155        else: result.append(chunk)
0156    return SQLQuery(''.join(result), vals)
0157
0158def sqlify(obj):
0159    """
0160    converts `obj` to its proper SQL version
0161    
0162        >>> sqlify(None)
0163        'NULL'
0164        >>> sqlify(True)
0165        "'t'"
0166        >>> sqlify(3)
0167        '3'
0168    """
0169
0170    # because `1 == True and hash(1) == hash(True)`
0171    # we have to do this the hard way...
0172
0173    if obj is None:
0174        return 'NULL'
0175    elif obj is True:
0176        return "'t'"
0177    elif obj is False:
0178        return "'f'"
0179    elif datetime and isinstance(obj, datetime.datetime):
0180        return repr(obj.isoformat())
0181    else:
0182        return repr(obj)
0183
0184class SQLQuery:
0185    """
0186    You can pass this sort of thing as a clause in any db function.
0187    Otherwise, you can pass a dictionary to the keyword argument `vars`
0188    and the function will call reparam for you.
0189    """
0190    # tested in sqlquote's docstring
0191    def __init__(self, s='', v=()):
0192        self.s, self.v = str(s), tuple(v)
0193
0194    def __getitem__(self, key): # for backwards-compatibility
0195        return [self.s, self.v][key]
0196
0197    def __add__(self, other):
0198        if isinstance(other, str):
0199            self.s += other
0200        elif isinstance(other, SQLQuery):
0201            self.s += other.s
0202            self.v += other.v
0203        return self
0204
0205    def __radd__(self, other):
0206        if isinstance(other, str):
0207            self.s = other + self.s
0208            return self
0209        else:
0210            return NotImplemented
0211
0212    def __str__(self):
0213        try:
0214            return self.s % tuple([sqlify(x) for x in self.v])
0215        except (ValueError, TypeError):
0216            return self.s
0217
0218    def __repr__(self):
0219        return '<sql: %s>' % repr(str(self))
0220
0221def sqlquote(a):
0222    """
0223    Ensures `a` is quoted properly for use in a SQL query.
0224    
0225        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
0226        <sql: "WHERE x = 't' AND y = 3">
0227    """
0228    return SQLQuery(aparam(), (a,))
0229
0230class UnknownDB(Exception):
0231    """raised for unsupported dbms"""
0232    pass
0233
0234def connect(dbn, **keywords):
0235    """
0236    Connects to the specified database. 
0237    
0238    `dbn` currently must be "postgres", "mysql", or "sqllite". 
0239    
0240    If DBUtils is installed, connection pooling will be used.
0241    """
0242    if dbn == "postgres":
0243        try:
0244            import psycopg2 as db
0245        except ImportError:
0246            try:
0247                import psycopg as db
0248            except ImportError:
0249                import pgdb as db
0250        if 'pw' in keywords:
0251            keywords['password'] = keywords['pw']
0252            del keywords['pw']
0253        keywords['database'] = keywords['db']
0254        del keywords['db']
0255
0256    elif dbn == "mysql":
0257        import MySQLdb as db
0258        if 'pw' in keywords:
0259            keywords['passwd'] = keywords['pw']
0260            del keywords['pw']
0261        db.paramstyle = 'pyformat' # it's both, like psycopg
0262
0263    elif dbn == "sqlite":
0264        try: ## try first sqlite3 version
0265            from pysqlite2 import dbapi2 as db
0266            db.paramstyle = 'qmark'
0267        except ImportError: ## else try sqlite2
0268            import sqlite as db
0269        keywords['database'] = keywords['db']
0270        del keywords['db']
0271
0272    elif dbn == "firebird":
0273        import kinterbasdb as db
0274        if 'pw' in keywords:
0275            keywords['passwd'] = keywords['pw']
0276            del keywords['pw']
0277        keywords['database'] = keywords['db']
0278        del keywords['db']
0279
0280    else:
0281        raise UnknownDB, dbn
0282
0283    web.ctx.db_name = dbn
0284    web.ctx.db_module = db
0285    web.ctx.db_transaction = False
0286    web.ctx.db = keywords
0287
0288    def db_cursor():
0289        if isinstance(web.ctx.db, dict):
0290            keywords = web.ctx.db
0291            if web.config._hasPooling:
0292                if 'db' not in globals():
0293                    globals()['db'] = PooledDB(dbapi=db, **keywords)
0294                web.ctx.db = globals()['db'].connection()
0295            else:
0296                web.ctx.db = db.connect(**keywords)
0297        return web.ctx.db.cursor()
0298    web.ctx.db_cursor = db_cursor
0299
0300    web.ctx.dbq_count = 0
0301
0302    def db_execute(cur, sql_query):
0303        """executes an sql query"""
0304
0305        web.ctx.dbq_count += 1
0306
0307        try:
0308            a = time.time()
0309            out = cur.execute(sql_query.s, sql_query.v)
0310            b = time.time()
0311        except:
0312            if web.config.get('db_printing'):
0313                print >> web.debug, 'ERR:', str(sql_query)
0314            rollback()
0315            raise
0316
0317        if web.config.get('db_printing'):
0318            print >> web.debug, '%s (%s): %s' % (round(b-a, 2), web.ctx.dbq_count, str(sql_query))
0319
0320        return out
0321    web.ctx.db_execute = db_execute
0322    return web.ctx.db
0323
0324def transact():
0325    """Start a transaction."""
0326    # commit everything up to now, so we don't rollback it later
0327    if hasattr(web.ctx.db, 'commit'): web.ctx.db.commit()
0328    web.ctx.db_transaction = True
0329
0330def commit():
0331    """Commits a transaction."""
0332    if hasattr(web.ctx.db, 'commit'): web.ctx.db.commit()
0333    web.ctx.db_transaction = False
0334
0335def rollback():
0336    """Rolls back a transaction."""
0337    if hasattr(web.ctx.db, 'rollback'): web.ctx.db.rollback()
0338    web.ctx.db_transaction = False
0339
0340def query(sql_query, vars=None, processed=False, _test=False):
0341    """
0342    Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
0343    If `processed=True`, `vars` is a `reparam`-style list to use 
0344    instead of interpolating.
0345    
0346        >>> query("SELECT * FROM foo", _test=True)
0347        <sql: 'SELECT * FROM foo'>
0348        >>> query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
0349        <sql: "SELECT * FROM foo WHERE x = 'f'">
0350        >>> query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
0351        <sql: "SELECT * FROM foo WHERE x = 'f'">
0352    """
0353    if vars is None: vars = {}
0354
0355    if not processed and not isinstance(sql_query, SQLQuery):
0356        sql_query = reparam(sql_query, vars)
0357
0358    if _test: return sql_query
0359
0360    db_cursor = web.ctx.db_cursor()
0361    web.ctx.db_execute(db_cursor, sql_query)
0362
0363    if db_cursor.description:
0364        names = [x[0] for x in db_cursor.description]
0365        def iterwrapper():
0366            row = db_cursor.fetchone()
0367            while row:
0368                yield storage(dict(zip(names, row)))
0369                row = db_cursor.fetchone()
0370        out = iterbetter(iterwrapper())
0371        out.__len__ = lambda: int(db_cursor.rowcount)
0372        out.list = lambda: [storage(dict(zip(names, x)))                              for x in db_cursor.fetchall()]
0374    else:
0375        out = db_cursor.rowcount
0376
0377    if not web.ctx.db_transaction: web.ctx.db.commit()
0378    return out
0379
0380def sqllist(lst):
0381    """
0382    Converts the arguments for use in something like a WHERE clause.
0383    
0384        >>> sqllist(['a', 'b'])
0385        'a, b'
0386        >>> sqllist('a')
0387        'a'
0388        
0389    """
0390    if isinstance(lst, str):
0391        return lst
0392    else:
0393        return ', '.join(lst)
0394
0395def sqlors(left, lst):
0396    """
0397    `left is a SQL clause like `tablename.arg = ` 
0398    and `lst` is a list of values. Returns a reparam-style
0399    pair featuring the SQL that ORs together the clause
0400    for each item in the lst.
0401
0402        >>> sqlors('foo = ', [])
0403        <sql: '2+2=5'>
0404        >>> sqlors('foo = ', [1])
0405        <sql: 'foo = 1'>
0406        >>> sqlors('foo = ', 1)
0407        <sql: 'foo = 1'>
0408        >>> sqlors('foo = ', [1,2,3])
0409        <sql: '(foo = 1 OR foo = 2 OR foo = 3)'>
0410    """
0411    if isinstance(lst, iters):
0412        lst = list(lst)
0413        ln = len(lst)
0414        if ln == 0:
0415            return SQLQuery("2+2=5", [])
0416        if ln == 1:
0417            lst = lst[0]
0418
0419    if isinstance(lst, iters):
0420        return SQLQuery('(' + left +
0421               (' OR ' + left).join([aparam() for param in lst]) + ")", lst)
0422    else:
0423        return SQLQuery(left + aparam(), [lst])
0424
0425def sqlwhere(dictionary, grouping=' AND '):
0426    """
0427    Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
0428    
0429        >>> sqlwhere({'cust_id': 2, 'order_id':3})
0430        <sql: 'order_id = 3 AND cust_id = 2'>
0431        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
0432        <sql: 'order_id = 3, cust_id = 2'>
0433    """
0434
0435    return SQLQuery(grouping.join([
0436      '%s = %s' % (k, aparam()) for k in dictionary.keys()
0437    ]), dictionary.values())
0438
0439def select(tables, vars=None, what='*', where=None, order=None, group=None,
0440           limit=None, offset=None, _test=False):
0441    """
0442    Selects `what` from `tables` with clauses `where`, `order`, 
0443    `group`, `limit`, and `offset`. Uses vars to interpolate. 
0444    Otherwise, each clause can be a SQLQuery.
0445    
0446        >>> select('foo', _test=True)
0447        <sql: 'SELECT * FROM foo'>
0448        >>> select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
0449        <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
0450    """
0451    if vars is None: vars = {}
0452    qout = ""
0453
0454    def gen_clause(sql, value):
0455        if isinstance(val, (int, long)):
0456            if sql == 'WHERE':
0457                nout = 'id = ' + sqlquote(val)
0458            else:
0459                nout = SQLQuery(val)
0460        elif isinstance(val, (list, tuple)) and len(val) == 2:
0461            nout = SQLQuery(val[0], val[1]) # backwards-compatibility
0462        elif isinstance(val, SQLQuery):
0463            nout = val
0464        elif val:
0465            nout = reparam(val, vars)
0466        else:
0467            return ""
0468
0469        out = ""
0470        if qout: out += " "
0471        out += sql + " " + nout
0472        return out
0473
0474    if web.ctx.get('db_name') == "firebird":
0475        for (sql, val) in (
0476           ('FIRST', limit),
0477           ('SKIP', offset)
0478        ):
0479            qout += gen_clause(sql, val)
0480        if qout:
0481            SELECT = 'SELECT ' + qout
0482        else:
0483            SELECT = 'SELECT'
0484        qout = ""
0485        sql_clauses = (
0486          (SELECT, what),
0487          ('FROM', sqllist(tables)),
0488          ('WHERE', where),
0489          ('GROUP BY', group),
0490          ('ORDER BY', order)
0491        )
0492    else:
0493        sql_clauses = (
0494          ('SELECT', what),
0495          ('FROM', sqllist(tables)),
0496          ('WHERE', where),
0497          ('GROUP BY', group),
0498          ('ORDER BY', order),
0499          ('LIMIT', limit),
0500          ('OFFSET', offset)
0501        )
0502
0503    for (sql, val) in sql_clauses:
0504        qout += gen_clause(sql, val)
0505
0506    if _test: return qout
0507    return query(qout, processed=True)
0508
0509def insert(tablename, seqname=None, _test=False, **values):
0510    """
0511    Inserts `values` into `tablename`. Returns current sequence ID.
0512    Set `seqname` to the ID if it's not the default, or to `False`
0513    if there isn't one.
0514    
0515        >>> insert('foo', joe='bob', a=2, _test=True)
0516        <sql: "INSERT INTO foo (a, joe) VALUES (2, 'bob')">
0517    """
0518
0519    if values:
0520        sql_query = SQLQuery("INSERT INTO %s (%s) VALUES (%s)" % (
0521            tablename,
0522            ", ".join(values.keys()),
0523            ', '.join([aparam() for x in values])
0524        ), values.values())
0525    else:
0526        sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)
0527
0528    if _test: return sql_query
0529
0530    db_cursor = web.ctx.db_cursor()
0531    if seqname is False:
0532        pass
0533    elif web.ctx.db_name == "postgres":
0534        if seqname is None:
0535            seqname = tablename + "_id_seq"
0536        sql_query += "; SELECT currval('%s')" % seqname
0537    elif web.ctx.db_name == "mysql":
0538        web.ctx.db_execute(db_cursor, sql_query)
0539        sql_query = SQLQuery("SELECT last_insert_id()")
0540    elif web.ctx.db_name == "sqlite":
0541        web.ctx.db_execute(db_cursor, sql_query)
0542        # not really the same...
0543        sql_query = SQLQuery("SELECT last_insert_rowid()")
0544
0545    web.ctx.db_execute(db_cursor, sql_query)
0546    try:
0547        out = db_cursor.fetchone()[0]
0548    except Exception:
0549        out = None
0550
0551    if not web.ctx.db_transaction: web.ctx.db.commit()
0552
0553    return out
0554
0555def update(tables, where, vars=None, _test=False, **values):
0556    """
0557    Update `tables` with clause `where` (interpolated using `vars`)
0558    and setting `values`.
0559    
0560        >>> joe = 'Joseph'
0561        >>> update('foo', where='name = $joe', name='bob', age=5,
0562        ...   vars=locals(), _test=True)
0563        <sql: "UPDATE foo SET age = 5, name = 'bob' WHERE name = 'Joseph'">
0564    """
0565    if vars is None: vars = {}
0566
0567    if isinstance(where, (int, long)):
0568        where = "id = " + sqlquote(where)
0569    elif isinstance(where, (list, tuple)) and len(where) == 2:
0570        where = SQLQuery(where[0], where[1])
0571    elif isinstance(where, SQLQuery):
0572        pass
0573    else:
0574        where = reparam(where, vars)
0575
0576    query = (
0577      "UPDATE " + sqllist(tables) +
0578      " SET " + sqlwhere(values, ', ') +
0579      " WHERE " + where)
0580
0581    if _test: return query
0582
0583    db_cursor = web.ctx.db_cursor()
0584    web.ctx.db_execute(db_cursor, query)
0585
0586    if not web.ctx.db_transaction: web.ctx.db.commit()
0587    return db_cursor.rowcount
0588
0589def delete(table, where, using=None, vars=None, _test=False):
0590    """
0591    Deletes from `table` with clauses `where` and `using`.
0592    
0593        >>> name = 'Joe'
0594        >>> delete('foo', where='name = $name', vars=locals(), _test=True)
0595        <sql: "DELETE FROM foo WHERE name = 'Joe'">
0596    """
0597    if vars is None: vars = {}
0598
0599    if isinstance(where, (int, long)):
0600        where = "id = " + sqlquote(where)
0601    elif isinstance(where, (list, tuple)) and len(where) == 2:
0602        where = SQLQuery(where[0], where[1])
0603    elif isinstance(where, SQLQuery):
0604        pass
0605    else:
0606        where = reparam(where, vars)
0607
0608    q = 'DELETE FROM ' + table + ' WHERE ' + where
0609    if using and web.ctx.get('db_name') != "firebird":
0610        q += ' USING ' + sqllist(using)
0611
0612    if _test: return q
0613
0614    db_cursor = web.ctx.db_cursor()
0615    web.ctx.db_execute(db_cursor, q)
0616
0617    if not web.ctx.db_transaction: web.ctx.db.commit()
0618    return db_cursor.rowcount
0619
0620if __name__ == "__main__":
0621    import doctest
0622    doctest.testmod()