0001"""
0002Database API
0003(part of web.py)
0004"""
0005
0006
0007
0008
0009
# Public names exported by a star-import of this module.
# Note: `sqlify` and `sqlwhere` are defined in this file but are not listed.
__all__ = [
    "UnknownParamstyle", "UnknownDB",
    "sqllist", "sqlors", "aparam", "reparam",
    "SQLQuery", "sqlquote",
    "connect",
    "transact", "commit", "rollback",
    "query",
    "select", "insert", "update", "delete"
]
0019
0020import time
0021try: import datetime
0022except ImportError: datetime = None
0023
0024from utils import storage, iters, iterbetter
0025import webapi as web
0026
# Optional dependency: record on `web.config` whether DBUtils' PooledDB could
# be imported. `connect()` reads this flag to decide between pooled and
# direct connections.
try:
    from DBUtils.PooledDB import PooledDB
    web.config._hasPooling = True
except ImportError:
    web.config._hasPooling = False
0032
class _ItplError(ValueError):
    """
    Raised while parsing a format string when no further token can be read
    inside an embedded expression.

    `text` is the format string being parsed and `pos` is the character
    offset at which tokenizing failed.
    """

    def __init__(self, text, pos):
        # No message is passed to ValueError; __str__ builds it on demand.
        ValueError.__init__(self)
        self.text, self.pos = text, pos

    def __str__(self):
        details = (repr(self.text), self.pos)
        return "unfinished expression in %s at char %d" % details
0041
def _interpolate(format):
    """
    Takes a format string and returns a list of 2-tuples of the form
    (boolean, string) where boolean says whether string should be evaled
    or not.

    Recognised forms are `$name` (optionally followed by `.attr`, `(...)`
    or `[...]` suffixes) and `${expression}`. `$$` yields a literal `$`,
    and a `$` followed by anything else -- including the end of the
    string -- is kept as literal text.

    Raises `_ItplError` when an embedded expression is not finished.

    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
    """
    try:
        from tokenize import tokenprog
    except ImportError:
        # The precompiled pattern is not available on every interpreter;
        # rebuild it from the regex source (assumes `tokenize.Token` exists).
        import re
        import tokenize
        tokenprog = re.compile(tokenize.Token)

    def matchorfail(text, pos):
        # Read one Python token starting at `pos` or report where it failed.
        match = tokenprog.match(text, pos)
        if match is None:
            raise _ItplError(text, pos)
        return match, match.end()

    namechars = ("abcdefghijklmnopqrstuvwxyz"
                 "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_")
    chunks = []
    pos = 0

    while 1:
        dollar = format.find("$", pos)
        if dollar < 0:
            break
        # Slice instead of index so that a trailing "$" gives "" rather
        # than raising IndexError.
        nextchar = format[dollar + 1:dollar + 2]

        if nextchar == "{":
            # ${expression}: consume tokens until the braces balance.
            chunks.append((0, format[pos:dollar]))
            pos, level = dollar + 2, 1
            while level:
                match, pos = matchorfail(format, pos)
                # Group 3 of the token pattern is the token without the
                # leading whitespace/comment.
                tstart, tend = match.regs[3]
                token = format[tstart:tend]
                if token == "{":
                    level = level + 1
                elif token == "}":
                    level = level - 1
            chunks.append((1, format[dollar + 2:pos - 1]))

        elif nextchar and nextchar in namechars:
            # $name, followed by any number of .attr / (...) / [...] parts.
            # (`nextchar` must be checked for emptiness: "" is "in" any str.)
            chunks.append((0, format[pos:dollar]))
            match, pos = matchorfail(format, dollar + 1)
            while pos < len(format):
                if (format[pos] == "." and pos + 1 < len(format)
                        and format[pos + 1] in namechars):
                    match, pos = matchorfail(format, pos + 1)
                elif format[pos] in "([":
                    pos, level = pos + 1, 1
                    while level:
                        match, pos = matchorfail(format, pos)
                        tstart, tend = match.regs[3]
                        token = format[tstart:tend]
                        if token[0] in "([":
                            level = level + 1
                        elif token[0] in ")]":
                            level = level - 1
                else:
                    break
            chunks.append((1, format[dollar + 1:pos]))

        else:
            # Literal "$"; for "$$" skip the second dollar as well.
            chunks.append((0, format[pos:dollar + 1]))
            pos = dollar + 1 + (nextchar == "$")

    if pos < len(format):
        chunks.append((0, format[pos:]))
    return chunks
0110
class UnknownParamstyle(Exception):
    """
    Signals a database module whose `paramstyle` is not handled here.

    (currently supported: qmark, numeric, format, pyformat)
    """
0118
def aparam():
    """
    Returns the placeholder string used to mark one interpolated value
    for the paramstyle of the current `web.ctx.db_module`, or `%s` when
    no database module has been set.

    Raises `UnknownParamstyle` for any other paramstyle.

        >>> aparam()
        '%s'
    """
    style = 'pyformat'
    if hasattr(web.ctx, 'db_module'):
        style = web.ctx.db_module.paramstyle

    if style in ('format', 'pyformat'):
        return '%s'
    if style == 'qmark':
        return '?'
    if style == 'numeric':
        # NOTE(review): every numeric placeholder is ':1'; a statement with
        # several parameters would repeat it -- confirm this is intended.
        return ':1'
    raise UnknownParamstyle(style)
0140
def reparam(string_, dictionary):
    """
    Interpolates `string_` with values looked up in `dictionary` and
    returns the outcome as an `SQLQuery`: each `$expression` is replaced
    by a placeholder and its evaluated value becomes a query parameter.

        >>> reparam("s = $s", dict(s=True))
        <sql: "s = 't'">
    """
    pieces = []
    params = []
    for is_expression, text in _interpolate(string_):
        if not is_expression:
            pieces.append(text)
            continue
        pieces.append(aparam())
        # NOTE: the embedded expression is run through eval() with
        # `dictionary` as its globals; never feed untrusted text as `string_`.
        params.append(eval(text, dictionary))
    return SQLQuery(''.join(pieces), params)
0157
def sqlify(obj):
    """
    Renders `obj` as the text of an SQL literal.

    `None` becomes NULL, the `True`/`False` singletons become 't'/'f',
    datetimes are quoted in ISO format and everything else is rendered
    with `repr`.

        >>> sqlify(None)
        'NULL'
        >>> sqlify(True)
        "'t'"
        >>> sqlify(3)
        '3'
    """
    if obj is None:
        return 'NULL'
    # Identity checks on purpose: the integers 1 and 0 must stay numbers.
    if obj is True:
        return "'t'"
    if obj is False:
        return "'f'"
    # `datetime` is None when the module could not be imported.
    if datetime and isinstance(obj, datetime.datetime):
        obj = obj.isoformat()
    return repr(obj)
0183
class SQLQuery:
    """
    You can pass this sort of thing as a clause in any db function.
    Otherwise, you can pass a dictionary to the keyword argument `vars`
    and the function will call reparam for you.

    `s` holds the SQL text with placeholders and `v` the tuple of values
    for those placeholders. Queries can be concatenated with strings and
    with other queries using `+`; concatenation always builds a new
    `SQLQuery` and leaves both operands untouched.
    """

    def __init__(self, s='', v=()):
        self.s, self.v = str(s), tuple(v)

    def __getitem__(self, key):
        # Lets a query be unpacked or indexed like an (sql, values) pair.
        return [self.s, self.v][key]

    def __add__(self, other):
        # Build a fresh object: mutating `self` here would silently change
        # a query that the caller may still be holding and reusing.
        if isinstance(other, str):
            return SQLQuery(self.s + other, self.v)
        elif isinstance(other, SQLQuery):
            return SQLQuery(self.s + other.s, self.v + other.v)
        else:
            # Refuse unknown operands instead of quietly dropping them.
            return NotImplemented

    def __radd__(self, other):
        if isinstance(other, str):
            return SQLQuery(other + self.s, self.v)
        else:
            return NotImplemented

    def __str__(self):
        # Display form only: substitute the sqlify'd values into the text.
        # Falls back to the raw text when %-formatting does not apply
        # (for instance when the placeholders are not `%s`).
        try:
            return self.s % tuple([sqlify(x) for x in self.v])
        except (ValueError, TypeError):
            return self.s

    def __repr__(self):
        return '<sql: %s>' % repr(str(self))
0220
def sqlquote(a):
    """
    Wraps the value `a` in an `SQLQuery` made of a single placeholder, so
    that it can be concatenated safely into the text of a SQL query.

        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
        <sql: "WHERE x = 't' AND y = 3">
    """
    placeholder = aparam()
    return SQLQuery(placeholder, (a,))
0229
class UnknownDB(Exception):
    """Signals a database name that `connect` does not support."""
0233
def connect(dbn, **keywords):
    """
    Connects to the specified database.

    `dbn` currently must be "postgres", "mysql", "sqlite", or "firebird";
    any other value raises `UnknownDB`.

    `keywords` are handed to the driver after a few renames: `pw` becomes
    the driver's password keyword and, for every backend except mysql,
    `db` becomes `database` (and is therefore required there).

    If DBUtils is installed, connection pooling will be used.

    No connection is opened here: the settings are stored on `web.ctx`
    and the real connection is made the first time `web.ctx.db_cursor()`
    is called. The return value is `web.ctx.db`, which at this point is
    still the dictionary of connection keywords.
    """
    if dbn == "postgres":
        # Try the drivers from most to least preferred.
        try:
            import psycopg2 as db
        except ImportError:
            try:
                import psycopg as db
            except ImportError:
                import pgdb as db
        if 'pw' in keywords:
            keywords['password'] = keywords['pw']
            del keywords['pw']
        keywords['database'] = keywords['db']
        del keywords['db']

    elif dbn == "mysql":
        import MySQLdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords['pw']
            del keywords['pw']
        # Force the paramstyle that aparam() should use for this driver.
        db.paramstyle = 'pyformat'

    elif dbn == "sqlite":
        try:
            from pysqlite2 import dbapi2 as db
            db.paramstyle = 'qmark'
        except ImportError:
            import sqlite as db
        keywords['database'] = keywords['db']
        del keywords['db']

    elif dbn == "firebird":
        import kinterbasdb as db
        if 'pw' in keywords:
            # NOTE(review): confirm that kinterbasdb accepts `passwd`
            # (its documented keyword appears to be `password`).
            keywords['passwd'] = keywords['pw']
            del keywords['pw']
        keywords['database'] = keywords['db']
        del keywords['db']

    else:
        raise UnknownDB, dbn

    web.ctx.db_name = dbn
    web.ctx.db_module = db
    web.ctx.db_transaction = False
    # Holds the keyword dict until db_cursor() swaps in a live connection.
    web.ctx.db = keywords

    def db_cursor():
        # Lazily open the connection on first use, then hand out cursors.
        if isinstance(web.ctx.db, dict):
            keywords = web.ctx.db
            if web.config._hasPooling:
                # The pool is cached in this module's globals under 'db'.
                # NOTE(review): once created it is reused as-is, so a later
                # connect() with a different driver or keywords would still
                # get connections from the first pool -- confirm intended.
                if 'db' not in globals():
                    globals()['db'] = PooledDB(dbapi=db, **keywords)
                web.ctx.db = globals()['db'].connection()
            else:
                web.ctx.db = db.connect(**keywords)
        return web.ctx.db.cursor()
    web.ctx.db_cursor = db_cursor

    # Number of statements executed through db_execute().
    web.ctx.dbq_count = 0

    def db_execute(cur, sql_query):
        """
        Executes `sql_query` (an object with `.s` text and `.v` values) on
        cursor `cur` and returns whatever the cursor's `execute` returns.

        On any error the transaction is rolled back and the error is
        re-raised. When `web.config` has `db_printing` set, the query
        (with its timing and running count) is printed to `web.debug`.
        """

        web.ctx.dbq_count += 1

        try:
            a = time.time()
            out = cur.execute(sql_query.s, sql_query.v)
            b = time.time()
        except:
            # Bare except is deliberate: always roll back, then re-raise.
            if web.config.get('db_printing'):
                print >> web.debug, 'ERR:', str(sql_query)
            rollback()
            raise

        if web.config.get('db_printing'):
            print >> web.debug, '%s (%s): %s' % (round(b-a, 2), web.ctx.dbq_count, str(sql_query))

        return out
    web.ctx.db_execute = db_execute
    return web.ctx.db
0323
def transact():
    """
    Start a transaction.

    Anything still pending on the connection is committed first; after
    that the per-statement auto-commit of the db functions is switched
    off until `commit` or `rollback` is called.
    """
    conn = web.ctx.db
    # `web.ctx.db` may not be a live connection yet; then there is
    # nothing to commit.
    if hasattr(conn, 'commit'):
        conn.commit()
    web.ctx.db_transaction = True
0329
def commit():
    """Commits the current transaction and leaves transaction mode."""
    conn = web.ctx.db
    # Skip the commit when `web.ctx.db` offers no such method.
    if hasattr(conn, 'commit'):
        conn.commit()
    web.ctx.db_transaction = False
0334
def rollback():
    """Rolls back the current transaction and leaves transaction mode."""
    conn = web.ctx.db
    # Skip the rollback when `web.ctx.db` offers no such method.
    if hasattr(conn, 'rollback'):
        conn.rollback()
    web.ctx.db_transaction = False
0339
def query(sql_query, vars=None, processed=False, _test=False):
    """
    Runs `sql_query`, interpolating it with the dictionary `vars` unless
    it is already an `SQLQuery` or `processed=True` says that it needs no
    interpolation.

    For statements that produce rows the result is an iterator of
    `storage` objects keyed by column name; it also has a `list()` method
    and reports the cursor's row count as its length. For other
    statements the result is the cursor's row count. With `_test=True`
    the prepared query is returned instead of being executed.

        >>> query("SELECT * FROM foo", _test=True)
        <sql: 'SELECT * FROM foo'>
        >>> query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
        <sql: "SELECT * FROM foo WHERE x = 'f'">
        >>> query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
        <sql: "SELECT * FROM foo WHERE x = 'f'">
    """
    if vars is None:
        vars = {}

    if not (processed or isinstance(sql_query, SQLQuery)):
        sql_query = reparam(sql_query, vars)

    if _test:
        return sql_query

    cur = web.ctx.db_cursor()
    web.ctx.db_execute(cur, sql_query)

    if not cur.description:
        # No result set: report the number of affected rows.
        result = cur.rowcount
    else:
        columns = [col[0] for col in cur.description]

        def rows():
            # Fetch lazily, one row at a time, until the cursor runs dry.
            while 1:
                record = cur.fetchone()
                if not record:
                    break
                yield storage(dict(zip(columns, record)))

        result = iterbetter(rows())
        result.__len__ = lambda: int(cur.rowcount)
        result.list = lambda: [storage(dict(zip(columns, r)))
                               for r in cur.fetchall()]

    # Outside an explicit transaction every statement commits right away.
    if not web.ctx.db_transaction:
        web.ctx.db.commit()
    return result
0379
def sqllist(lst):
    """
    Joins a sequence of names with commas, for use in a column or table
    list. A plain string is returned unchanged.

        >>> sqllist(['a', 'b'])
        'a, b'
        >>> sqllist('a')
        'a'

    """
    if not isinstance(lst, str):
        lst = ', '.join(lst)
    return lst
0394
def sqlors(left, lst):
    """
    `left` is a SQL clause like `tablename.arg = `
    and `lst` is a list of values. Returns an `SQLQuery`
    that ORs together the clause for each item in `lst`.

    An empty list gives a condition that is never true, and a single
    value (bare or inside a list) gives the clause without parentheses.

        >>> sqlors('foo = ', [])
        <sql: '2+2=5'>
        >>> sqlors('foo = ', [1])
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', 1)
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', [1,2,3])
        <sql: '(foo = 1 OR foo = 2 OR foo = 3)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        count = len(lst)
        if count == 0:
            return SQLQuery("2+2=5", [])
        if count == 1:
            # Unwrap; the element itself is re-examined below.
            lst = lst[0]

    if not isinstance(lst, iters):
        return SQLQuery(left + aparam(), [lst])

    placeholders = [aparam() for _ in lst]
    separator = ' OR ' + left
    return SQLQuery('(' + left + separator.join(placeholders) + ")", lst)
0424
def sqlwhere(dictionary, grouping=' AND '):
    """
    Builds an `SQLQuery` of `key = value` conditions from `dictionary`,
    joined with `grouping`. Keys are inserted as-is as column names;
    values become query parameters.

        >>> sqlwhere({'cust_id': 2, 'order_id':3})
        <sql: 'order_id = 3 AND cust_id = 2'>
        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
        <sql: 'order_id = 3, cust_id = 2'>
    """
    conditions = []
    for column in dictionary.keys():
        conditions.append('%s = %s' % (column, aparam()))
    return SQLQuery(grouping.join(conditions), dictionary.values())
0438
def select(tables, vars=None, what='*', where=None, order=None, group=None,
           limit=None, offset=None, _test=False):
    """
    Selects `what` from `tables` with clauses `where`, `order`,
    `group`, `limit`, and `offset`. Uses vars to interpolate.
    Otherwise, each clause can be a SQLQuery.

    A clause may be given as a string (interpolated with `vars`), an
    `SQLQuery`, a two-item `(sql, values)` list or tuple, or an integer.
    An integer `where` means `id = <that integer>`; an integer in any
    other clause is used literally. Empty clauses are left out.

    With `_test=True` the assembled query is returned instead of being
    executed.

        >>> select('foo', _test=True)
        <sql: 'SELECT * FROM foo'>
        >>> select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
        <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
    """
    if vars is None: vars = {}
    qout = ""

    def gen_clause(sql, val):
        # Render one "<sql> <val>" fragment, or "" when `val` is empty.
        # Reads `qout` from the enclosing scope to decide whether a
        # separating space is needed.
        if isinstance(val, (int, long)):
            if sql == 'WHERE':
                nout = 'id = ' + sqlquote(val)
            else:
                nout = SQLQuery(val)
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            nout = SQLQuery(val[0], val[1])
        elif isinstance(val, SQLQuery):
            # Work on a copy so that the concatenation below cannot alter
            # the caller's object.
            nout = SQLQuery(val.s, val.v)
        elif val:
            nout = reparam(val, vars)
        else:
            return ""

        out = ""
        if qout: out += " "
        out += sql + " " + nout
        return out

    if web.ctx.get('db_name') == "firebird":
        # Firebird puts the row limits right after SELECT (FIRST n SKIP m)
        # instead of using trailing LIMIT/OFFSET clauses.
        for (sql, val) in (
            ('FIRST', limit),
            ('SKIP', offset)
        ):
            qout += gen_clause(sql, val)
        if qout:
            SELECT = 'SELECT ' + qout
        else:
            SELECT = 'SELECT'
        qout = ""
        sql_clauses = (
            (SELECT, what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order)
        )
    else:
        sql_clauses = (
            ('SELECT', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('LIMIT', limit),
            ('OFFSET', offset)
        )

    for (sql, val) in sql_clauses:
        qout += gen_clause(sql, val)

    if _test: return qout
    return query(qout, processed=True)
0508
def insert(tablename, seqname=None, _test=False, **values):
    """
    Inserts a row built from `values` into `tablename` and returns the
    id of the new row, or `None` when no id could be read back.

    `seqname` names the sequence holding the id when it is not the
    default; pass `False` when there is none. With no `values` a row of
    defaults is inserted. With `_test=True` the INSERT query is returned
    instead of being executed.

        >>> insert('foo', joe='bob', a=2, _test=True)
        <sql: "INSERT INTO foo (a, joe) VALUES (2, 'bob')">
    """
    if not values:
        sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)
    else:
        statement = "INSERT INTO %s (%s) VALUES (%s)" % (
            tablename,
            ", ".join(values.keys()),
            ', '.join([aparam() for _ in values])
        )
        sql_query = SQLQuery(statement, values.values())

    if _test:
        return sql_query

    # Backends that need a separate statement to read the new id.
    last_id_queries = {
        "mysql": "SELECT last_insert_id()",
        "sqlite": "SELECT last_insert_rowid()",
    }

    cur = web.ctx.db_cursor()
    if seqname is not False:
        dbname = web.ctx.db_name
        if dbname == "postgres":
            # Read the sequence in the same round trip as the INSERT.
            if seqname is None:
                seqname = tablename + "_id_seq"
            sql_query += "; SELECT currval('%s')" % seqname
        elif dbname in last_id_queries:
            # Run the INSERT now; the id query is executed below.
            web.ctx.db_execute(cur, sql_query)
            sql_query = SQLQuery(last_id_queries[dbname])

    web.ctx.db_execute(cur, sql_query)
    try:
        out = cur.fetchone()[0]
    except Exception:
        # Best effort: no id is available for this statement/backend.
        out = None

    if not web.ctx.db_transaction:
        web.ctx.db.commit()

    return out
0554
def update(tables, where, vars=None, _test=False, **values):
    """
    Sets the columns given in `values` on the rows of `tables` that
    match `where`, and returns the cursor's row count.

    `where` may be a string (interpolated with `vars`), an `SQLQuery`, a
    two-item `(sql, values)` list or tuple, or an integer meaning
    `id = <that integer>`. With `_test=True` the UPDATE query is returned
    instead of being executed.

        >>> joe = 'Joseph'
        >>> update('foo', where='name = $joe', name='bob', age=5,
        ...   vars=locals(), _test=True)
        <sql: "UPDATE foo SET age = 5, name = 'bob' WHERE name = 'Joseph'">
    """
    if vars is None:
        vars = {}

    if isinstance(where, (int, long)):
        condition = "id = " + sqlquote(where)
    elif isinstance(where, (list, tuple)) and len(where) == 2:
        condition = SQLQuery(where[0], where[1])
    elif isinstance(where, SQLQuery):
        condition = where
    else:
        condition = reparam(where, vars)

    target = sqllist(tables)
    # sqlwhere with ', ' as separator yields the "col = value, ..." list.
    assignments = sqlwhere(values, ', ')
    statement = "UPDATE " + target + " SET " + assignments + " WHERE " + condition

    if _test:
        return statement

    cur = web.ctx.db_cursor()
    web.ctx.db_execute(cur, statement)

    if not web.ctx.db_transaction:
        web.ctx.db.commit()
    return cur.rowcount
0588
def delete(table, where, using=None, vars=None, _test=False):
    """
    Deletes from `table` with clauses `where` and `using`.

    `where` may be a string (interpolated with `vars`), an `SQLQuery`, a
    two-item `(sql, values)` list or tuple, or an integer meaning
    `id = <that integer>`. `using` is a table name or a list of table
    names for the USING clause; it is ignored when the current database
    is firebird.

    Returns the cursor's row count, or the DELETE query itself when
    `_test=True`.

        >>> name = 'Joe'
        >>> delete('foo', where='name = $name', vars=locals(), _test=True)
        <sql: "DELETE FROM foo WHERE name = 'Joe'">
        >>> delete('foo', where='foo.id = bar.id', using='bar', _test=True)
        <sql: 'DELETE FROM foo USING bar WHERE foo.id = bar.id'>
    """
    if vars is None: vars = {}

    if isinstance(where, (int, long)):
        where = "id = " + sqlquote(where)
    elif isinstance(where, (list, tuple)) and len(where) == 2:
        where = SQLQuery(where[0], where[1])
    elif isinstance(where, SQLQuery):
        pass
    else:
        where = reparam(where, vars)

    q = 'DELETE FROM ' + table
    # USING has to come before WHERE: "DELETE FROM t USING u WHERE ...".
    if using and web.ctx.get('db_name') != "firebird":
        q += ' USING ' + sqllist(using)
    q = q + ' WHERE ' + where

    if _test: return q

    db_cursor = web.ctx.db_cursor()
    web.ctx.db_execute(db_cursor, q)

    # Outside an explicit transaction every statement commits right away.
    if not web.ctx.db_transaction: web.ctx.db.commit()
    return db_cursor.rowcount
0619
# When run as a script, execute the doctest examples embedded in this
# module's docstrings.
if __name__ == "__main__":
    import doctest
    doctest.testmod()