cursors.py | cursors.py | |||
---|---|---|---|---|
"""MySQLdb Cursors | """MySQLdb Cursors | |||
This module implements Cursors of various types for MySQLdb. By | This module implements Cursors of various types for MySQLdb. By | |||
default, MySQLdb uses the Cursor class. | default, MySQLdb uses the Cursor class. | |||
""" | """ | |||
from __future__ import print_function, absolute_import | ||||
from functools import partial | ||||
import re | import re | |||
import sys | import sys | |||
PY2 = sys.version_info[0] == 2 | ||||
from MySQLdb.compat import unicode | from MySQLdb.compat import unicode | |||
from _mysql_exceptions import ( | ||||
Warning, Error, InterfaceError, DataError, | ||||
DatabaseError, OperationalError, IntegrityError, InternalError, | ||||
NotSupportedError, ProgrammingError) | ||||
restr = r""" | PY2 = sys.version_info[0] == 2 | |||
\s | if PY2: | |||
values | text_type = unicode | |||
\s* | else: | |||
( | text_type = str | |||
\( | ||||
[^()']* | #: Regular expression for :meth:`Cursor.executemany`. | |||
(?: | #: executemany only supports simple bulk insert. | |||
(?: | #: You can use it to load large dataset. | |||
(?:\( | RE_INSERT_VALUES = re.compile( | |||
# ( - editor highlighting helper | r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + | |||
.* | r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + | |||
\)) | r"(\s*(?:ON DUPLICATE.*)?)\Z", | |||
| | re.IGNORECASE | re.DOTALL) | |||
' | ||||
[^\\']* | ||||
(?:\\.[^\\']*)* | ||||
' | ||||
) | ||||
[^()']* | ||||
)* | ||||
\) | ||||
) | ||||
""" | ||||
insert_values = re.compile(restr, re.S | re.I | re.X) | ||||
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ | ||||
DatabaseError, OperationalError, IntegrityError, InternalError, \ | ||||
NotSupportedError, ProgrammingError | ||||
class BaseCursor(object): | class BaseCursor(object): | |||
"""A base for Cursor classes. Useful attributes: | """A base for Cursor classes. Useful attributes: | |||
description | description | |||
A tuple of DB API 7-tuples describing the columns in | A tuple of DB API 7-tuples describing the columns in | |||
the last executed query; see PEP-249 for details. | the last executed query; see PEP-249 for details. | |||
description_flags | description_flags | |||
Tuple of column flags for last query, one entry per column | Tuple of column flags for last query, one entry per column | |||
in the result set. Values correspond to those in | in the result set. Values correspond to those in | |||
MySQLdb.constants.FLAG. See MySQL documentation (C API) | MySQLdb.constants.FLAG. See MySQL documentation (C API) | |||
for more information. Non-standard extension. | for more information. Non-standard extension. | |||
arraysize | arraysize | |||
default number of rows fetchmany() will fetch | default number of rows fetchmany() will fetch | |||
""" | """ | |||
#: Max stetement size which :meth:`executemany` generates. | ||||
#: | ||||
#: Max size of allowed statement is max_allowed_packet - packet_header_ | ||||
size. | ||||
#: Default value of max_allowed_packet is 1048576. | ||||
max_stmt_length = 64*1024 | ||||
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceErro r, \ | from _mysql_exceptions import MySQLError, Warning, Error, InterfaceErro r, \ | |||
DatabaseError, DataError, OperationalError, IntegrityError, \ | DatabaseError, DataError, OperationalError, IntegrityError, \ | |||
InternalError, ProgrammingError, NotSupportedError | InternalError, ProgrammingError, NotSupportedError | |||
_defer_warnings = False | _defer_warnings = False | |||
connection = None | ||||
def __init__(self, connection): | def __init__(self, connection): | |||
from weakref import ref | self.connection = connection | |||
self.connection = ref(connection) | ||||
self.description = None | self.description = None | |||
self.description_flags = None | self.description_flags = None | |||
self.rowcount = -1 | self.rowcount = -1 | |||
self.arraysize = 1 | self.arraysize = 1 | |||
self._executed = None | self._executed = None | |||
self.lastrowid = None | self.lastrowid = None | |||
self.messages = [] | self.messages = [] | |||
self.errorhandler = connection.errorhandler | self.errorhandler = connection.errorhandler | |||
self._result = None | self._result = None | |||
self._warnings = 0 | self._warnings = None | |||
self._info = None | ||||
self.rownumber = None | self.rownumber = None | |||
def close(self): | def close(self): | |||
"""Close the cursor. No further queries will be possible.""" | """Close the cursor. No further queries will be possible.""" | |||
try: | try: | |||
if self.connection is None or self.connection() is None: | if self.connection is None: | |||
return | return | |||
while self.nextset(): | while self.nextset(): | |||
pass | pass | |||
finally: | finally: | |||
self.connection = None | self.connection = None | |||
self.errorhandler = None | self.errorhandler = None | |||
self._result = None | self._result = None | |||
def __enter__(self): | def __enter__(self): | |||
return self | return self | |||
def __exit__(self, *exc_info): | def __exit__(self, *exc_info): | |||
del exc_info | del exc_info | |||
self.close() | self.close() | |||
def _ensure_bytes(self, x, encoding=None): | ||||
if isinstance(x, text_type): | ||||
x = x.encode(encoding) | ||||
elif isinstance(x, (tuple, list)): | ||||
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x | ||||
) | ||||
return x | ||||
def _escape_args(self, args, conn): | ||||
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) | ||||
if isinstance(args, (tuple, list)): | ||||
if PY2: | ||||
args = tuple(map(ensure_bytes, args)) | ||||
return tuple(conn.literal(arg) for arg in args) | ||||
elif isinstance(args, dict): | ||||
if PY2: | ||||
args = dict((ensure_bytes(key), ensure_bytes(val)) for | ||||
(key, val) in args.items()) | ||||
return dict((key, conn.literal(val)) for (key, val) in args.ite | ||||
ms()) | ||||
else: | ||||
# If it's not a dictionary let's try escaping it anyways. | ||||
# Worst case it will throw a Value error | ||||
if PY2: | ||||
args = ensure_bytes(args) | ||||
return conn.literal(args) | ||||
def _check_executed(self): | def _check_executed(self): | |||
if not self._executed: | if not self._executed: | |||
self.errorhandler(self, ProgrammingError, "execute() first") | self.errorhandler(self, ProgrammingError, "execute() first") | |||
def _warning_check(self): | def _warning_check(self): | |||
from warnings import warn | from warnings import warn | |||
db = self._get_db() | ||||
# None => warnings not interrogated for current query yet | ||||
# 0 => no warnings exists or have been handled already for this que | ||||
ry | ||||
if self._warnings is None: | ||||
self._warnings = db.warning_count() | ||||
if self._warnings: | if self._warnings: | |||
# Only propagate warnings for current query once | ||||
warning_count = self._warnings | ||||
self._warnings = 0 | ||||
# When there is next result, fetching warnings cause "command | # When there is next result, fetching warnings cause "command | |||
# out of sync" error. | # out of sync" error. | |||
if self._result and self._result.has_next: | if self._result and self._result.has_next: | |||
msg = "There are %d MySQL warnings." % (self._warnings,) | msg = "There are %d MySQL warnings." % (warning_count,) | |||
self.messages.append(msg) | self.messages.append(msg) | |||
warn(msg, self.Warning, 3) | warn(self.Warning(0, msg), stacklevel=3) | |||
return | return | |||
warnings = self._get_db().show_warnings() | warnings = db.show_warnings() | |||
if warnings: | if warnings: | |||
# This is done in two loops in case | # This is done in two loops in case | |||
# Warnings are set to raise exceptions. | # Warnings are set to raise exceptions. | |||
for w in warnings: | for w in warnings: | |||
self.messages.append((self.Warning, w)) | self.messages.append((self.Warning, w)) | |||
for w in warnings: | for w in warnings: | |||
warn(w[-1], self.Warning, 3) | warn(self.Warning(*w[1:3]), stacklevel=3) | |||
elif self._info: | else: | |||
self.messages.append((self.Warning, self._info)) | info = db.info() | |||
warn(self._info, self.Warning, 3) | if info: | |||
self.messages.append((self.Warning, info)) | ||||
warn(self.Warning(0, info), stacklevel=3) | ||||
def nextset(self): | def nextset(self): | |||
"""Advance to the next result set. | """Advance to the next result set. | |||
Returns None if there are no more result sets. | Returns None if there are no more result sets. | |||
""" | """ | |||
if self._executed: | if self._executed: | |||
self.fetchall() | self.fetchall() | |||
del self.messages[:] | del self.messages[:] | |||
skipping to change at line 159 | skipping to change at line 187 | |||
def _post_get_result(self): pass | def _post_get_result(self): pass | |||
def _do_get_result(self): | def _do_get_result(self): | |||
db = self._get_db() | db = self._get_db() | |||
self._result = self._get_result() | self._result = self._get_result() | |||
self.rowcount = db.affected_rows() | self.rowcount = db.affected_rows() | |||
self.rownumber = 0 | self.rownumber = 0 | |||
self.description = self._result and self._result.describe() or None | self.description = self._result and self._result.describe() or None | |||
self.description_flags = self._result and self._result.field_flags( ) or None | self.description_flags = self._result and self._result.field_flags( ) or None | |||
self.lastrowid = db.insert_id() | self.lastrowid = db.insert_id() | |||
self._warnings = db.warning_count() | self._warnings = None | |||
self._info = db.info() | ||||
def setinputsizes(self, *args): | def setinputsizes(self, *args): | |||
"""Does nothing, required by DB API.""" | """Does nothing, required by DB API.""" | |||
def setoutputsizes(self, *args): | def setoutputsizes(self, *args): | |||
"""Does nothing, required by DB API.""" | """Does nothing, required by DB API.""" | |||
def _get_db(self): | def _get_db(self): | |||
con = self.connection | con = self.connection | |||
if con is not None: | ||||
con = con() | ||||
if con is None: | if con is None: | |||
raise ProgrammingError("cursor closed") | raise ProgrammingError("cursor closed") | |||
return con | return con | |||
def execute(self, query, args=None): | def execute(self, query, args=None): | |||
"""Execute a query. | """Execute a query. | |||
query -- string, query to execute on server | query -- string, query to execute on server | |||
args -- optional sequence or mapping, parameters to use with query. | args -- optional sequence or mapping, parameters to use with query. | |||
Note: If args is a sequence, then %s must be used as the | Note: If args is a sequence, then %s must be used as the | |||
parameter placeholder in the query. If a mapping is used, | parameter placeholder in the query. If a mapping is used, | |||
%(key)s must be used as the placeholder. | %(key)s must be used as the placeholder. | |||
Returns long integer rows affected, if any | Returns integer represents rows affected, if any | |||
""" | """ | |||
while self.nextset(): | while self.nextset(): | |||
pass | pass | |||
db = self._get_db() | db = self._get_db() | |||
# NOTE: | # NOTE: | |||
# Python 2: query should be bytes when executing %. | # Python 2: query should be bytes when executing %. | |||
# All unicode in args should be encoded to bytes on Python 2. | # All unicode in args should be encoded to bytes on Python 2. | |||
# Python 3: query should be str (unicode) when executing %. | # Python 3: query should be str (unicode) when executing %. | |||
# All bytes in args should be decoded with ascii and surrogateescap e on Python 3. | # All bytes in args should be decoded with ascii and surrogateescap e on Python 3. | |||
# db.literal(obj) always returns str. | # db.literal(obj) always returns str. | |||
if PY2 and isinstance(query, unicode): | if PY2 and isinstance(query, unicode): | |||
query = query.encode(db.unicode_literal.charset) | query = query.encode(db.unicode_literal.charset) | |||
if args is not None: | if args is not None: | |||
if isinstance(args, dict): | if isinstance(args, dict): | |||
args = dict((key, db.literal(item)) for key, item in args.i tems()) | args = dict((key, db.literal(item)) for key, item in args.i tems()) | |||
else: | else: | |||
args = tuple(map(db.literal, args)) | args = tuple(map(db.literal, args)) | |||
if not PY2 and isinstance(query, bytes): | if not PY2 and isinstance(query, (bytes, bytearray)): | |||
query = query.decode(db.unicode_literal.charset) | query = query.decode(db.unicode_literal.charset) | |||
query = query % args | try: | |||
query = query % args | ||||
except TypeError as m: | ||||
self.errorhandler(self, ProgrammingError, str(m)) | ||||
if isinstance(query, unicode): | if isinstance(query, unicode): | |||
query = query.encode(db.unicode_literal.charset, 'surrogateesca pe') | query = query.encode(db.unicode_literal.charset, 'surrogateesca pe') | |||
res = None | res = None | |||
try: | try: | |||
res = self._query(query) | res = self._query(query) | |||
except TypeError as m: | ||||
if m.args[0] in ("not enough arguments for format string", | ||||
"not all arguments converted"): | ||||
self.errorhandler(self, ProgrammingError, m.args[0]) | ||||
else: | ||||
self.errorhandler(self, TypeError, m) | ||||
except Exception: | except Exception: | |||
exc, value = sys.exc_info()[:2] | exc, value = sys.exc_info()[:2] | |||
self.errorhandler(self, exc, value) | self.errorhandler(self, exc, value) | |||
self._executed = query | self._executed = query | |||
if not self._defer_warnings: self._warning_check() | if not self._defer_warnings: | |||
self._warning_check() | ||||
return res | return res | |||
def executemany(self, query, args): | def executemany(self, query, args): | |||
# type: (str, list) -> int | ||||
"""Execute a multi-row query. | """Execute a multi-row query. | |||
query -- string, query to execute on server | :param query: query to execute on server | |||
:param args: Sequence of sequences or mappings. It is used as par | ||||
args | ameter. | |||
:return: Number of rows affected, if any. | ||||
Sequence of sequences or mappings, parameters to use with | ||||
query. | ||||
Returns long integer rows affected, if any. | ||||
This method improves performance on multiple-row INSERT and | This method improves performance on multiple-row INSERT and | |||
REPLACE. Otherwise it is equivalent to looping over args with | REPLACE. Otherwise it is equivalent to looping over args with | |||
execute(). | execute(). | |||
""" | """ | |||
del self.messages[:] | del self.messages[:] | |||
db = self._get_db() | ||||
if not args: return | if not args: | |||
if PY2 and isinstance(query, unicode): | return | |||
query = query.encode(db.unicode_literal.charset) | ||||
elif not PY2 and isinstance(query, bytes): | m = RE_INSERT_VALUES.match(query) | |||
query = query.decode(db.unicode_literal.charset) | if m: | |||
m = insert_values.search(query) | q_prefix = m.group(1) % () | |||
if not m: | q_values = m.group(2).rstrip() | |||
r = 0 | q_postfix = m.group(3) or '' | |||
for a in args: | assert q_values[0] == '(' and q_values[-1] == ')' | |||
r = r + self.execute(query, a) | return self._do_execute_many(q_prefix, q_values, q_postfix, arg | |||
return r | s, | |||
p = m.start(1) | self.max_stmt_length, | |||
e = m.end(1) | self._get_db().encoding) | |||
qv = m.group(1) | ||||
try: | self.rowcount = sum(self.execute(query, arg) for arg in args) | |||
q = [] | return self.rowcount | |||
for a in args: | ||||
if isinstance(a, dict): | def _do_execute_many(self, prefix, values, postfix, args, max_stmt_leng | |||
q.append(qv % dict((key, db.literal(item)) | th, encoding): | |||
for key, item in a.items())) | conn = self._get_db() | |||
escape = self._escape_args | ||||
if isinstance(prefix, text_type): | ||||
prefix = prefix.encode(encoding) | ||||
if PY2 and isinstance(values, text_type): | ||||
values = values.encode(encoding) | ||||
if isinstance(postfix, text_type): | ||||
postfix = postfix.encode(encoding) | ||||
sql = bytearray(prefix) | ||||
args = iter(args) | ||||
v = values % escape(next(args), conn) | ||||
if isinstance(v, text_type): | ||||
if PY2: | ||||
v = v.encode(encoding) | ||||
else: | ||||
v = v.encode(encoding, 'surrogateescape') | ||||
sql += v | ||||
rows = 0 | ||||
for arg in args: | ||||
v = values % escape(arg, conn) | ||||
if isinstance(v, text_type): | ||||
if PY2: | ||||
v = v.encode(encoding) | ||||
else: | else: | |||
q.append(qv % tuple([db.literal(item) for item in a])) | v = v.encode(encoding, 'surrogateescape') | |||
except TypeError as msg: | if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: | |||
if msg.args[0] in ("not enough arguments for format string", | rows += self.execute(sql + postfix) | |||
"not all arguments converted"): | sql = bytearray(prefix) | |||
self.errorhandler(self, ProgrammingError, msg.args[0]) | ||||
else: | else: | |||
self.errorhandler(self, TypeError, msg) | sql += b',' | |||
except (SystemExit, KeyboardInterrupt): | sql += v | |||
raise | rows += self.execute(sql + postfix) | |||
except: | self.rowcount = rows | |||
exc, value = sys.exc_info()[:2] | return rows | |||
self.errorhandler(self, exc, value) | ||||
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]]) | ||||
if not PY2: | ||||
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape') | ||||
r = self._query(qs) | ||||
if not self._defer_warnings: self._warning_check() | ||||
return r | ||||
def callproc(self, procname, args=()): | def callproc(self, procname, args=()): | |||
"""Execute stored procedure procname with args | """Execute stored procedure procname with args | |||
procname -- string, name of procedure to execute on server | procname -- string, name of procedure to execute on server | |||
args -- Sequence of parameters to use with procedure | args -- Sequence of parameters to use with procedure | |||
Returns the original args. | Returns the original args. | |||
skipping to change at line 322 | skipping to change at line 353 | |||
behavior with respect to the DB-API. Be sure to use nextset() | behavior with respect to the DB-API. Be sure to use nextset() | |||
to advance through all result sets; otherwise you may get | to advance through all result sets; otherwise you may get | |||
disconnected. | disconnected. | |||
""" | """ | |||
db = self._get_db() | db = self._get_db() | |||
for index, arg in enumerate(args): | for index, arg in enumerate(args): | |||
q = "SET @_%s_%d=%s" % (procname, index, | q = "SET @_%s_%d=%s" % (procname, index, | |||
db.literal(arg)) | db.literal(arg)) | |||
if isinstance(q, unicode): | if isinstance(q, unicode): | |||
q = q.encode(db.unicode_literal.charset) | q = q.encode(db.unicode_literal.charset, 'surrogateescape') | |||
self._query(q) | self._query(q) | |||
self.nextset() | self.nextset() | |||
q = "CALL %s(%s)" % (procname, | q = "CALL %s(%s)" % (procname, | |||
','.join(['@_%s_%d' % (procname, i) | ','.join(['@_%s_%d' % (procname, i) | |||
for i in range(len(args))])) | for i in range(len(args))])) | |||
if isinstance(q, unicode): | if isinstance(q, unicode): | |||
q = q.encode(db.unicode_literal.charset) | q = q.encode(db.unicode_literal.charset, 'surrogateescape') | |||
self._query(q) | self._query(q) | |||
self._executed = q | self._executed = q | |||
if not self._defer_warnings: | if not self._defer_warnings: | |||
self._warning_check() | self._warning_check() | |||
return args | return args | |||
def _do_query(self, q): | def _do_query(self, q): | |||
db = self._get_db() | db = self._get_db() | |||
self._last_executed = q | self._last_executed = q | |||
db.query(q) | db.query(q) | |||
skipping to change at line 367 | skipping to change at line 398 | |||
InterfaceError = InterfaceError | InterfaceError = InterfaceError | |||
DatabaseError = DatabaseError | DatabaseError = DatabaseError | |||
DataError = DataError | DataError = DataError | |||
OperationalError = OperationalError | OperationalError = OperationalError | |||
IntegrityError = IntegrityError | IntegrityError = IntegrityError | |||
InternalError = InternalError | InternalError = InternalError | |||
ProgrammingError = ProgrammingError | ProgrammingError = ProgrammingError | |||
NotSupportedError = NotSupportedError | NotSupportedError = NotSupportedError | |||
class CursorStoreResultMixIn(object): | class CursorStoreResultMixIn(object): | |||
"""This is a MixIn class which causes the entire result set to be | """This is a MixIn class which causes the entire result set to be | |||
stored on the client side, i.e. it uses mysql_store_result(). If the | stored on the client side, i.e. it uses mysql_store_result(). If the | |||
result set can be very large, consider adding a LIMIT clause to your | result set can be very large, consider adding a LIMIT clause to your | |||
query, or using CursorUseResultMixIn instead.""" | query, or using CursorUseResultMixIn instead.""" | |||
def _get_result(self): return self._get_db().store_result() | def _get_result(self): | |||
return self._get_db().store_result() | ||||
def _query(self, q): | def _query(self, q): | |||
rowcount = self._do_query(q) | rowcount = self._do_query(q) | |||
self._post_get_result() | self._post_get_result() | |||
return rowcount | return rowcount | |||
def _post_get_result(self): | def _post_get_result(self): | |||
self._rows = self._fetch_row(0) | self._rows = self._fetch_row(0) | |||
self._result = None | self._result = None | |||
def fetchone(self): | def fetchone(self): | |||
"""Fetches a single row from the cursor. None indicates that | """Fetches a single row from the cursor. None indicates that | |||
no more rows are available.""" | no more rows are available.""" | |||
self._check_executed() | self._check_executed() | |||
if self.rownumber >= len(self._rows): return None | if self.rownumber >= len(self._rows): | |||
return None | ||||
result = self._rows[self.rownumber] | result = self._rows[self.rownumber] | |||
self.rownumber = self.rownumber+1 | self.rownumber = self.rownumber + 1 | |||
return result | return result | |||
def fetchmany(self, size=None): | def fetchmany(self, size=None): | |||
"""Fetch up to size rows from the cursor. Result set may be smaller | """Fetch up to size rows from the cursor. Result set may be smaller | |||
than size. If size is not defined, cursor.arraysize is used.""" | than size. If size is not defined, cursor.arraysize is used.""" | |||
self._check_executed() | self._check_executed() | |||
end = self.rownumber + (size or self.arraysize) | end = self.rownumber + (size or self.arraysize) | |||
result = self._rows[self.rownumber:end] | result = self._rows[self.rownumber:end] | |||
self.rownumber = min(end, len(self._rows)) | self.rownumber = min(end, len(self._rows)) | |||
return result | return result | |||
skipping to change at line 441 | skipping to change at line 473 | |||
def __iter__(self): | def __iter__(self): | |||
self._check_executed() | self._check_executed() | |||
result = self.rownumber and self._rows[self.rownumber:] or self._ro ws | result = self.rownumber and self._rows[self.rownumber:] or self._ro ws | |||
return iter(result) | return iter(result) | |||
class CursorUseResultMixIn(object): | class CursorUseResultMixIn(object): | |||
"""This is a MixIn class which causes the result set to be stored | """This is a MixIn class which causes the result set to be stored | |||
in the server and sent row-by-row to client side, i.e. it uses | in the server and sent row-by-row to client side, i.e. it uses | |||
mysql_use_result(). You MUST retrieve the entire result set and | mysql_use_result(). You MUST retrieve the entire result set and | |||
close() the cursor before additional queries can be peformed on | close() the cursor before additional queries can be performed on | |||
the connection.""" | the connection.""" | |||
_defer_warnings = True | _defer_warnings = True | |||
def _get_result(self): return self._get_db().use_result() | def _get_result(self): return self._get_db().use_result() | |||
def fetchone(self): | def fetchone(self): | |||
"""Fetches a single row from the cursor.""" | """Fetches a single row from the cursor.""" | |||
self._check_executed() | self._check_executed() | |||
r = self._fetch_row(1) | r = self._fetch_row(1) | |||
End of changes. 36 change blocks. | ||||
113 lines changed or deleted | 152 lines changed or added | |||
This html diff was produced by rfcdiff 1.41. The latest version is available from http://tools.ietf.org/tools/rfcdiff/ |