cursors.py   cursors.py 
"""MySQLdb Cursors """MySQLdb Cursors
This module implements Cursors of various types for MySQLdb. By This module implements Cursors of various types for MySQLdb. By
default, MySQLdb uses the Cursor class. default, MySQLdb uses the Cursor class.
""" """
from __future__ import print_function, absolute_import
from functools import partial
import re import re
import sys import sys
PY2 = sys.version_info[0] == 2
from MySQLdb.compat import unicode from MySQLdb.compat import unicode
from _mysql_exceptions import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
restr = r""" PY2 = sys.version_info[0] == 2
\s if PY2:
values text_type = unicode
\s* else:
( text_type = str
\(
[^()']* #: Regular expression for :meth:`Cursor.executemany`.
(?: #: executemany only supports simple bulk insert.
(?: #: You can use it to load large dataset.
(?:\( RE_INSERT_VALUES = re.compile(
# ( - editor highlighting helper r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
.* r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
\)) r"(\s*(?:ON DUPLICATE.*)?)\Z",
| re.IGNORECASE | re.DOTALL)
'
[^\\']*
(?:\\.[^\\']*)*
'
)
[^()']*
)*
\)
)
"""
insert_values = re.compile(restr, re.S | re.I | re.X)
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError
class BaseCursor(object): class BaseCursor(object):
"""A base for Cursor classes. Useful attributes: """A base for Cursor classes. Useful attributes:
description description
A tuple of DB API 7-tuples describing the columns in A tuple of DB API 7-tuples describing the columns in
the last executed query; see PEP-249 for details. the last executed query; see PEP-249 for details.
description_flags description_flags
Tuple of column flags for last query, one entry per column Tuple of column flags for last query, one entry per column
in the result set. Values correspond to those in in the result set. Values correspond to those in
MySQLdb.constants.FLAG. See MySQL documentation (C API) MySQLdb.constants.FLAG. See MySQL documentation (C API)
for more information. Non-standard extension. for more information. Non-standard extension.
arraysize arraysize
default number of rows fetchmany() will fetch default number of rows fetchmany() will fetch
""" """
#: Max stetement size which :meth:`executemany` generates.
#:
#: Max size of allowed statement is max_allowed_packet - packet_header_
size.
#: Default value of max_allowed_packet is 1048576.
max_stmt_length = 64*1024
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceErro r, \ from _mysql_exceptions import MySQLError, Warning, Error, InterfaceErro r, \
DatabaseError, DataError, OperationalError, IntegrityError, \ DatabaseError, DataError, OperationalError, IntegrityError, \
InternalError, ProgrammingError, NotSupportedError InternalError, ProgrammingError, NotSupportedError
_defer_warnings = False _defer_warnings = False
connection = None
def __init__(self, connection): def __init__(self, connection):
from weakref import ref self.connection = connection
self.connection = ref(connection)
self.description = None self.description = None
self.description_flags = None self.description_flags = None
self.rowcount = -1 self.rowcount = -1
self.arraysize = 1 self.arraysize = 1
self._executed = None self._executed = None
self.lastrowid = None self.lastrowid = None
self.messages = [] self.messages = []
self.errorhandler = connection.errorhandler self.errorhandler = connection.errorhandler
self._result = None self._result = None
self._warnings = 0 self._warnings = None
self._info = None
self.rownumber = None self.rownumber = None
def close(self): def close(self):
"""Close the cursor. No further queries will be possible.""" """Close the cursor. No further queries will be possible."""
try: try:
if self.connection is None or self.connection() is None: if self.connection is None:
return return
while self.nextset(): while self.nextset():
pass pass
finally: finally:
self.connection = None self.connection = None
self.errorhandler = None self.errorhandler = None
self._result = None self._result = None
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, *exc_info): def __exit__(self, *exc_info):
del exc_info del exc_info
self.close() self.close()
def _ensure_bytes(self, x, encoding=None):
if isinstance(x, text_type):
x = x.encode(encoding)
elif isinstance(x, (tuple, list)):
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x
)
return x
def _escape_args(self, args, conn):
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
if isinstance(args, (tuple, list)):
if PY2:
args = tuple(map(ensure_bytes, args))
return tuple(conn.literal(arg) for arg in args)
elif isinstance(args, dict):
if PY2:
args = dict((ensure_bytes(key), ensure_bytes(val)) for
(key, val) in args.items())
return dict((key, conn.literal(val)) for (key, val) in args.ite
ms())
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
if PY2:
args = ensure_bytes(args)
return conn.literal(args)
def _check_executed(self): def _check_executed(self):
if not self._executed: if not self._executed:
self.errorhandler(self, ProgrammingError, "execute() first") self.errorhandler(self, ProgrammingError, "execute() first")
def _warning_check(self): def _warning_check(self):
from warnings import warn from warnings import warn
db = self._get_db()
# None => warnings not interrogated for current query yet
# 0 => no warnings exists or have been handled already for this que
ry
if self._warnings is None:
self._warnings = db.warning_count()
if self._warnings: if self._warnings:
# Only propagate warnings for current query once
warning_count = self._warnings
self._warnings = 0
# When there is next result, fetching warnings cause "command # When there is next result, fetching warnings cause "command
# out of sync" error. # out of sync" error.
if self._result and self._result.has_next: if self._result and self._result.has_next:
msg = "There are %d MySQL warnings." % (self._warnings,) msg = "There are %d MySQL warnings." % (warning_count,)
self.messages.append(msg) self.messages.append(msg)
warn(msg, self.Warning, 3) warn(self.Warning(0, msg), stacklevel=3)
return return
warnings = self._get_db().show_warnings() warnings = db.show_warnings()
if warnings: if warnings:
# This is done in two loops in case # This is done in two loops in case
# Warnings are set to raise exceptions. # Warnings are set to raise exceptions.
for w in warnings: for w in warnings:
self.messages.append((self.Warning, w)) self.messages.append((self.Warning, w))
for w in warnings: for w in warnings:
warn(w[-1], self.Warning, 3) warn(self.Warning(*w[1:3]), stacklevel=3)
elif self._info: else:
self.messages.append((self.Warning, self._info)) info = db.info()
warn(self._info, self.Warning, 3) if info:
self.messages.append((self.Warning, info))
warn(self.Warning(0, info), stacklevel=3)
def nextset(self): def nextset(self):
"""Advance to the next result set. """Advance to the next result set.
Returns None if there are no more result sets. Returns None if there are no more result sets.
""" """
if self._executed: if self._executed:
self.fetchall() self.fetchall()
del self.messages[:] del self.messages[:]
skipping to change at line 159 skipping to change at line 187
def _post_get_result(self): pass def _post_get_result(self): pass
def _do_get_result(self): def _do_get_result(self):
db = self._get_db() db = self._get_db()
self._result = self._get_result() self._result = self._get_result()
self.rowcount = db.affected_rows() self.rowcount = db.affected_rows()
self.rownumber = 0 self.rownumber = 0
self.description = self._result and self._result.describe() or None self.description = self._result and self._result.describe() or None
self.description_flags = self._result and self._result.field_flags( ) or None self.description_flags = self._result and self._result.field_flags( ) or None
self.lastrowid = db.insert_id() self.lastrowid = db.insert_id()
self._warnings = db.warning_count() self._warnings = None
self._info = db.info()
def setinputsizes(self, *args): def setinputsizes(self, *args):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""
def setoutputsizes(self, *args): def setoutputsizes(self, *args):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""
def _get_db(self): def _get_db(self):
con = self.connection con = self.connection
if con is not None:
con = con()
if con is None: if con is None:
raise ProgrammingError("cursor closed") raise ProgrammingError("cursor closed")
return con return con
def execute(self, query, args=None): def execute(self, query, args=None):
"""Execute a query. """Execute a query.
query -- string, query to execute on server query -- string, query to execute on server
args -- optional sequence or mapping, parameters to use with query. args -- optional sequence or mapping, parameters to use with query.
Note: If args is a sequence, then %s must be used as the Note: If args is a sequence, then %s must be used as the
parameter placeholder in the query. If a mapping is used, parameter placeholder in the query. If a mapping is used,
%(key)s must be used as the placeholder. %(key)s must be used as the placeholder.
Returns long integer rows affected, if any Returns integer represents rows affected, if any
""" """
while self.nextset(): while self.nextset():
pass pass
db = self._get_db() db = self._get_db()
# NOTE: # NOTE:
# Python 2: query should be bytes when executing %. # Python 2: query should be bytes when executing %.
# All unicode in args should be encoded to bytes on Python 2. # All unicode in args should be encoded to bytes on Python 2.
# Python 3: query should be str (unicode) when executing %. # Python 3: query should be str (unicode) when executing %.
# All bytes in args should be decoded with ascii and surrogateescap e on Python 3. # All bytes in args should be decoded with ascii and surrogateescap e on Python 3.
# db.literal(obj) always returns str. # db.literal(obj) always returns str.
if PY2 and isinstance(query, unicode): if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset) query = query.encode(db.unicode_literal.charset)
if args is not None: if args is not None:
if isinstance(args, dict): if isinstance(args, dict):
args = dict((key, db.literal(item)) for key, item in args.i tems()) args = dict((key, db.literal(item)) for key, item in args.i tems())
else: else:
args = tuple(map(db.literal, args)) args = tuple(map(db.literal, args))
if not PY2 and isinstance(query, bytes): if not PY2 and isinstance(query, (bytes, bytearray)):
query = query.decode(db.unicode_literal.charset) query = query.decode(db.unicode_literal.charset)
query = query % args try:
query = query % args
except TypeError as m:
self.errorhandler(self, ProgrammingError, str(m))
if isinstance(query, unicode): if isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset, 'surrogateesca pe') query = query.encode(db.unicode_literal.charset, 'surrogateesca pe')
res = None res = None
try: try:
res = self._query(query) res = self._query(query)
except TypeError as m:
if m.args[0] in ("not enough arguments for format string",
"not all arguments converted"):
self.errorhandler(self, ProgrammingError, m.args[0])
else:
self.errorhandler(self, TypeError, m)
except Exception: except Exception:
exc, value = sys.exc_info()[:2] exc, value = sys.exc_info()[:2]
self.errorhandler(self, exc, value) self.errorhandler(self, exc, value)
self._executed = query self._executed = query
if not self._defer_warnings: self._warning_check() if not self._defer_warnings:
self._warning_check()
return res return res
def executemany(self, query, args): def executemany(self, query, args):
# type: (str, list) -> int
"""Execute a multi-row query. """Execute a multi-row query.
query -- string, query to execute on server :param query: query to execute on server
:param args: Sequence of sequences or mappings. It is used as par
args ameter.
:return: Number of rows affected, if any.
Sequence of sequences or mappings, parameters to use with
query.
Returns long integer rows affected, if any.
This method improves performance on multiple-row INSERT and This method improves performance on multiple-row INSERT and
REPLACE. Otherwise it is equivalent to looping over args with REPLACE. Otherwise it is equivalent to looping over args with
execute(). execute().
""" """
del self.messages[:] del self.messages[:]
db = self._get_db()
if not args: return if not args:
if PY2 and isinstance(query, unicode): return
query = query.encode(db.unicode_literal.charset)
elif not PY2 and isinstance(query, bytes): m = RE_INSERT_VALUES.match(query)
query = query.decode(db.unicode_literal.charset) if m:
m = insert_values.search(query) q_prefix = m.group(1) % ()
if not m: q_values = m.group(2).rstrip()
r = 0 q_postfix = m.group(3) or ''
for a in args: assert q_values[0] == '(' and q_values[-1] == ')'
r = r + self.execute(query, a) return self._do_execute_many(q_prefix, q_values, q_postfix, arg
return r s,
p = m.start(1) self.max_stmt_length,
e = m.end(1) self._get_db().encoding)
qv = m.group(1)
try: self.rowcount = sum(self.execute(query, arg) for arg in args)
q = [] return self.rowcount
for a in args:
if isinstance(a, dict): def _do_execute_many(self, prefix, values, postfix, args, max_stmt_leng
q.append(qv % dict((key, db.literal(item)) th, encoding):
for key, item in a.items())) conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, text_type):
prefix = prefix.encode(encoding)
if PY2 and isinstance(values, text_type):
values = values.encode(encoding)
if isinstance(postfix, text_type):
postfix = postfix.encode(encoding)
sql = bytearray(prefix)
args = iter(args)
v = values % escape(next(args), conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
v = v.encode(encoding, 'surrogateescape')
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else: else:
q.append(qv % tuple([db.literal(item) for item in a])) v = v.encode(encoding, 'surrogateescape')
except TypeError as msg: if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
if msg.args[0] in ("not enough arguments for format string", rows += self.execute(sql + postfix)
"not all arguments converted"): sql = bytearray(prefix)
self.errorhandler(self, ProgrammingError, msg.args[0])
else: else:
self.errorhandler(self, TypeError, msg) sql += b','
except (SystemExit, KeyboardInterrupt): sql += v
raise rows += self.execute(sql + postfix)
except: self.rowcount = rows
exc, value = sys.exc_info()[:2] return rows
self.errorhandler(self, exc, value)
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
if not PY2:
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
r = self._query(qs)
if not self._defer_warnings: self._warning_check()
return r
def callproc(self, procname, args=()): def callproc(self, procname, args=()):
"""Execute stored procedure procname with args """Execute stored procedure procname with args
procname -- string, name of procedure to execute on server procname -- string, name of procedure to execute on server
args -- Sequence of parameters to use with procedure args -- Sequence of parameters to use with procedure
Returns the original args. Returns the original args.
skipping to change at line 322 skipping to change at line 353
behavior with respect to the DB-API. Be sure to use nextset() behavior with respect to the DB-API. Be sure to use nextset()
to advance through all result sets; otherwise you may get to advance through all result sets; otherwise you may get
disconnected. disconnected.
""" """
db = self._get_db() db = self._get_db()
for index, arg in enumerate(args): for index, arg in enumerate(args):
q = "SET @_%s_%d=%s" % (procname, index, q = "SET @_%s_%d=%s" % (procname, index,
db.literal(arg)) db.literal(arg))
if isinstance(q, unicode): if isinstance(q, unicode):
q = q.encode(db.unicode_literal.charset) q = q.encode(db.unicode_literal.charset, 'surrogateescape')
self._query(q) self._query(q)
self.nextset() self.nextset()
q = "CALL %s(%s)" % (procname, q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i) ','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))])) for i in range(len(args))]))
if isinstance(q, unicode): if isinstance(q, unicode):
q = q.encode(db.unicode_literal.charset) q = q.encode(db.unicode_literal.charset, 'surrogateescape')
self._query(q) self._query(q)
self._executed = q self._executed = q
if not self._defer_warnings: if not self._defer_warnings:
self._warning_check() self._warning_check()
return args return args
def _do_query(self, q): def _do_query(self, q):
db = self._get_db() db = self._get_db()
self._last_executed = q self._last_executed = q
db.query(q) db.query(q)
skipping to change at line 367 skipping to change at line 398
InterfaceError = InterfaceError InterfaceError = InterfaceError
DatabaseError = DatabaseError DatabaseError = DatabaseError
DataError = DataError DataError = DataError
OperationalError = OperationalError OperationalError = OperationalError
IntegrityError = IntegrityError IntegrityError = IntegrityError
InternalError = InternalError InternalError = InternalError
ProgrammingError = ProgrammingError ProgrammingError = ProgrammingError
NotSupportedError = NotSupportedError NotSupportedError = NotSupportedError
class CursorStoreResultMixIn(object): class CursorStoreResultMixIn(object):
"""This is a MixIn class which causes the entire result set to be """This is a MixIn class which causes the entire result set to be
stored on the client side, i.e. it uses mysql_store_result(). If the stored on the client side, i.e. it uses mysql_store_result(). If the
result set can be very large, consider adding a LIMIT clause to your result set can be very large, consider adding a LIMIT clause to your
query, or using CursorUseResultMixIn instead.""" query, or using CursorUseResultMixIn instead."""
def _get_result(self): return self._get_db().store_result() def _get_result(self):
return self._get_db().store_result()
def _query(self, q): def _query(self, q):
rowcount = self._do_query(q) rowcount = self._do_query(q)
self._post_get_result() self._post_get_result()
return rowcount return rowcount
def _post_get_result(self): def _post_get_result(self):
self._rows = self._fetch_row(0) self._rows = self._fetch_row(0)
self._result = None self._result = None
def fetchone(self): def fetchone(self):
"""Fetches a single row from the cursor. None indicates that """Fetches a single row from the cursor. None indicates that
no more rows are available.""" no more rows are available."""
self._check_executed() self._check_executed()
if self.rownumber >= len(self._rows): return None if self.rownumber >= len(self._rows):
return None
result = self._rows[self.rownumber] result = self._rows[self.rownumber]
self.rownumber = self.rownumber+1 self.rownumber = self.rownumber + 1
return result return result
def fetchmany(self, size=None): def fetchmany(self, size=None):
"""Fetch up to size rows from the cursor. Result set may be smaller """Fetch up to size rows from the cursor. Result set may be smaller
than size. If size is not defined, cursor.arraysize is used.""" than size. If size is not defined, cursor.arraysize is used."""
self._check_executed() self._check_executed()
end = self.rownumber + (size or self.arraysize) end = self.rownumber + (size or self.arraysize)
result = self._rows[self.rownumber:end] result = self._rows[self.rownumber:end]
self.rownumber = min(end, len(self._rows)) self.rownumber = min(end, len(self._rows))
return result return result
skipping to change at line 441 skipping to change at line 473
def __iter__(self): def __iter__(self):
self._check_executed() self._check_executed()
result = self.rownumber and self._rows[self.rownumber:] or self._ro ws result = self.rownumber and self._rows[self.rownumber:] or self._ro ws
return iter(result) return iter(result)
class CursorUseResultMixIn(object): class CursorUseResultMixIn(object):
"""This is a MixIn class which causes the result set to be stored """This is a MixIn class which causes the result set to be stored
in the server and sent row-by-row to client side, i.e. it uses in the server and sent row-by-row to client side, i.e. it uses
mysql_use_result(). You MUST retrieve the entire result set and mysql_use_result(). You MUST retrieve the entire result set and
close() the cursor before additional queries can be peformed on close() the cursor before additional queries can be performed on
the connection.""" the connection."""
_defer_warnings = True _defer_warnings = True
def _get_result(self): return self._get_db().use_result() def _get_result(self): return self._get_db().use_result()
def fetchone(self): def fetchone(self):
"""Fetches a single row from the cursor.""" """Fetches a single row from the cursor."""
self._check_executed() self._check_executed()
r = self._fetch_row(1) r = self._fetch_row(1)
 End of changes. 36 change blocks. 
113 lines changed or deleted 152 lines changed or added

This html diff was produced by rfcdiff 1.41. The latest version is available from http://tools.ietf.org/tools/rfcdiff/