openmedialibrary_platform_l.../lib/python3.7/site-packages/sqlalchemy/testing/assertsql.py

373 lines
12 KiB
Python
Raw Normal View History

2016-02-06 09:37:19 +00:00
# testing/assertsql.py
2016-02-22 11:13:36 +00:00
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
2016-02-06 09:37:19 +00:00
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from ..engine.default import DefaultDialect
from .. import util
import re
2016-02-22 11:13:36 +00:00
import collections
import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
from sqlalchemy.engine import url
2016-02-06 09:37:19 +00:00
class AssertRule(object):
    """Base class for rules that are checked against observed SQL executions.

    Subclasses override :meth:`process_statement` and communicate their
    outcome through the three attributes below.
    """

    # True once the rule has matched and should be dropped by the asserter.
    is_consumed = False
    # Failure text; when set on an unconsumed rule the asserter fails with it.
    errormessage = None
    # When true, the asserter advances to the next observed execution
    # after this rule has processed the current one.
    consume_statement = True

    def process_statement(self, execute_observed):
        """Examine one observed execution; the base rule does nothing."""

    def no_more_statements(self):
        """Fail because the observed executions ran out before this rule."""
        assert False, (
            'All statements are complete, but pending '
            'assertion rules remain'
        )
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class SQLMatchRule(AssertRule):
    """Marker base class for rules that match SQL text; adds no behavior."""
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class CursorSQL(SQLMatchRule):
    """Rule matching the exact statement (and optionally parameters) that
    was passed to the cursor.

    The rule looks at the first cursor-level statement still attached to
    the observed execution and removes it on a match, so several
    ``CursorSQL`` rules can consume the statements of one execution in turn.
    """

    # Keep the observed execution in place until all of its cursor
    # statements have been matched.
    consume_statement = False

    def __init__(self, statement, params=None):
        """Store the expected SQL string and, optionally, its parameters.

        :param statement: exact SQL string expected at the cursor.
        :param params: expected cursor parameters; ``None`` skips the
         parameter comparison.
        """
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        """Compare the first pending cursor statement with the expectation.

        On mismatch ``errormessage`` is set.  On a match the statement is
        removed from ``execute_observed.statements``, the rule is marked
        consumed, and ``consume_statement`` is switched on once no cursor
        statements remain on the observed execution.
        """
        first = execute_observed.statements[0]

        mismatch = self.statement != first.statement
        if not mismatch and self.params is not None:
            # parameters are only compared when the SQL text already agrees
            mismatch = self.params != first.parameters

        if mismatch:
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s" % (
                    self.statement, self.params,
                    first.statement, first.parameters
                )
            )
            return

        execute_observed.statements.pop(0)
        self.is_consumed = True
        if not execute_observed.statements:
            self.consume_statement = True
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class CompiledSQL(SQLMatchRule):
    """Rule matching an observed execution against an expected SQL string.

    The observed statement is recompiled against the dialect selected by
    ``dialect`` and compared, with newlines and tabs stripped from both
    sides, to ``statement``.  When ``params`` is truthy, every expected
    parameter dictionary must be contained in a distinct received parameter
    dictionary, and the number of expected and received dictionaries must
    agree.
    """

    def __init__(self, statement, params=None, dialect='default'):
        """Store the expectation.

        :param statement: expected SQL string.
        :param params: expected parameters: a mapping, a list of mappings,
         or a callable that receives the execution context and returns
         either.  A falsy value disables the parameter check.
        :param dialect: ``'default'`` for ``DefaultDialect``, otherwise a
         dialect name resolved through ``url.URL(...).get_dialect()``.
        """
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        """Return whether the received SQL equals the expected SQL once
        newlines and tabs are removed from the expectation."""
        stmt = re.sub(r'[\n\t]', '', self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to recompile the statement."""
        if self.dialect == 'default':
            return DefaultDialect()
        else:
            # ugh
            if self.dialect == 'postgresql':
                params = {'implicit_returning': True}
            else:
                params = {}
            return url.URL(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect.

        :return: a ``(statement_string, list_of_parameter_dicts)`` tuple.
        """
        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = \
                context.compiled.statement.compile(dialect=compare_dialect)
        else:
            compiled = (
                context.compiled.statement.compile(
                    dialect=compare_dialect,
                    column_keys=context.compiled.column_keys,
                    inline=context.compiled.inline)
            )
        _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [compiled.construct_params()]
        else:
            _received_parameters = [
                compiled.construct_params(m) for m in parameters]

        return _received_statement, _received_parameters

    @staticmethod
    def _params_match(expected, received):
        """Return True when every expected parameter set is contained in a
        distinct received parameter set and nothing is left over.

        Only a "positive" comparison is made: keys that appear in a received
        set but not in the expected one are ignored.
        """
        pending_expected = list(expected)
        pending_received = list(received)

        while pending_expected and pending_received:
            wanted = dict(pending_expected.pop(0))

            for idx, candidate in enumerate(pending_received):
                differs = any(
                    key not in candidate or candidate[key] != wanted[key]
                    for key in wanted
                )
                if not differs:
                    # all keys in 'wanted' matched this candidate;
                    # claim it and move on to the next expected set
                    del pending_received[idx]
                    break
            else:
                # 'wanted' did not match any remaining received set
                return False

        return not pending_expected and not pending_received

    def process_statement(self, execute_observed):
        """Compare one observed execution with the expectation.

        Marks the rule consumed and clears ``errormessage`` on a match;
        otherwise stores a descriptive failure in ``errormessage``.
        """
        context = execute_observed.context

        _received_statement, _received_parameters = \
            self._received_statement(execute_observed)
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent and params is not None:
            equivalent = self._params_match(params, _received_parameters)

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                'received_statement': _received_statement,
                'received_parameters': _received_parameters
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or ``None`` when no
        parameter check was requested (``self.params`` is falsy)."""
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        """Return a %-format template describing the failed expectation.

        The result is interpolated a second time by ``process_statement``
        with the ``received_statement`` / ``received_parameters`` keys, so
        any literal ``%`` coming from the expected statement or from the
        repr of the expected parameters is doubled here; otherwise a value
        such as ``'50%'`` would break that second interpolation.
        """
        return (
            'Testing for compiled statement %r partial params %s, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                self.statement.replace('%', '%%'),
                repr(expected_params).replace('%', '%%'),
            )
        )
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class RegexSQL(CompiledSQL):
    """Variant of ``CompiledSQL`` whose expectation is a regular expression
    matched (with ``re.match``) against the statement as recompiled for the
    default dialect."""

    def __init__(self, regex, params=None):
        """Store the expectation.

        :param regex: pattern passed to ``re.compile``.
        :param params: expected parameters, interpreted as in
         ``CompiledSQL``.
        """
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = 'default'

    def _failure_message(self, expected_params):
        """Return a %-format template describing the failed expectation.

        The template is interpolated again by ``process_statement``, so
        literal ``%`` characters in the pattern or in the repr of the
        expected parameters are doubled here to survive that second pass.
        """
        return (
            'Testing for compiled statement ~%s partial params %s, '
            'received %%(received_statement)r with params '
            '%%(received_parameters)r' % (
                repr(self.orig_regex).replace('%', '%%'),
                repr(expected_params).replace('%', '%%'),
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        """Return whether the pattern matches at the start of the received
        statement."""
        return bool(self.regex.match(received_statement))
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class DialectSQL(CompiledSQL):
    """Variant of ``CompiledSQL`` that compares against the dialect actually
    used by the observed execution rather than a fixed one.

    The expected statement is written with ``:name`` placeholders; before
    comparison they are rewritten into the paramstyle of the executing
    dialect.
    """

    def _compile_dialect(self, execute_observed):
        """Return the dialect of the observed execution context."""
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        """Compare two statements after stripping newlines and tabs from
        ``real_stmt``."""
        stmt = re.sub(r'[\n\t]', '', real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Return the recompiled statement and the context's compiled
        parameters.

        :raises AssertionError: if the recompiled statement is not among
         the statements that were actually sent to the cursor.
        """
        received_stmt, _ = super(DialectSQL, self).\
            _received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt)

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        """Return whether the received SQL equals the expected SQL once the
        expectation's ``:name`` placeholders are converted to the paramstyle
        of the executing dialect.

        ``numeric`` placeholders are numbered ``:1``, ``:2``, ... in order
        of appearance.  Any other paramstyle (notably ``named``) leaves the
        ``:name`` form untouched instead of failing inside ``re.sub``.
        """
        stmt = re.sub(r'[\n\t]', '', self.statement)
        placeholder = r':([\w_]+)'

        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == 'pyformat':
            stmt = re.sub(placeholder, r"%(\1)s", stmt)
        elif paramstyle == 'qmark':
            stmt = re.sub(placeholder, "?", stmt)
        elif paramstyle == 'format':
            stmt = re.sub(placeholder, r"%s", stmt)
        elif paramstyle == 'numeric':
            seen = []

            def numbered(match):
                # positions are 1-based and follow order of appearance
                seen.append(match.group(1))
                return ':%d' % len(seen)

            stmt = re.sub(placeholder, numbered, stmt)

        return received_statement == stmt
2016-02-06 09:37:19 +00:00
class CountStatements(AssertRule):
    """Rule that counts observed executions and checks the total at the end.

    The rule never marks itself consumed; the comparison happens in
    :meth:`no_more_statements`.
    """

    def __init__(self, count):
        """Remember the desired number of executions and start at zero."""
        self._statement_count = 0
        self.count = count

    def process_statement(self, execute_observed):
        """Count one more observed execution."""
        self._statement_count = self._statement_count + 1

    def no_more_statements(self):
        """Fail if the number of executions seen differs from ``count``."""
        mismatch = self.count != self._statement_count
        assert not mismatch, \
            'desired statement count %d does not match %d' % (
                self.count, self._statement_count)
2016-02-06 09:37:19 +00:00
class AllOf(AssertRule):
    """Composite rule satisfied once every contained rule has been consumed.

    The contained rules are kept in a set, so they are not tried in any
    particular order.
    """

    def __init__(self, *rules):
        """Collect the rules that must all be consumed."""
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        """Offer the observed execution to each remaining rule in turn.

        Stops at the first rule that either consumes the execution or
        reports no error (meaning it is still in progress).  If every
        remaining rule reported an error, the error of one of them becomes
        this rule's ``errormessage``.
        """
        for candidate in tuple(self.rules):
            candidate.errormessage = None
            candidate.process_statement(execute_observed)

            if candidate.is_consumed:
                self.rules.discard(candidate)
                if not self.rules:
                    self.is_consumed = True
                return

            if not candidate.errormessage:
                # rule is not done yet
                self.errormessage = None
                return

        # no rule accepted the execution
        self.errormessage = list(self.rules)[0].errormessage
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class Or(AllOf):
    """Composite rule satisfied as soon as any one contained rule is
    consumed."""

    def process_statement(self, execute_observed):
        """Offer the observed execution to each rule until one consumes it.

        If none does, the ``errormessage`` of one of the contained rules is
        copied onto this rule.
        """
        for alternative in self.rules:
            alternative.process_statement(execute_observed)
            if alternative.is_consumed:
                self.is_consumed = True
                return

        self.errormessage = list(self.rules)[0].errormessage
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class SQLExecuteObserved(object):
    """Record of one connection-level execution and the cursor-level
    statements collected for it."""

    def __init__(self, context, clauseelement, multiparams, params):
        """Capture the execution context, the executed construct and its
        parameters (normalized through ``_distill_params``)."""
        # filled in afterwards with one entry per cursor execution
        self.statements = []
        self.parameters = _distill_params(multiparams, params)
        self.clauseelement = clauseelement
        self.context = context
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class SQLCursorExecuteObserved(
        collections.namedtuple(
            "SQLCursorExecuteObserved",
            "statement parameters context executemany")):
    """Named tuple describing a single cursor-level execution: the SQL
    string, its parameters, the execution context and the executemany
    flag."""
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
class SQLAsserter(object):
    """Collects observed executions and checks them against assertion
    rules."""

    def __init__(self):
        """Start with an empty list of accumulated executions."""
        self.accumulated = []

    def _close(self):
        """Freeze the accumulated executions into ``_final`` and remove the
        ``accumulated`` attribute."""
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Check the finalized executions against ``rules`` in order.

        Each rule processes the current execution; a consumed rule is
        dropped, an unconsumed rule with an ``errormessage`` fails the
        assertion, and the execution is advanced when the rule's
        ``consume_statement`` flag is true.  Leftover rules are given a
        chance to fail via ``no_more_statements()``; leftover executions
        fail the assertion.
        """
        pending_rules = list(rules)
        pending_observed = list(self._final)

        while pending_observed and pending_rules:
            current = pending_rules[0]
            current.process_statement(pending_observed[0])

            if current.is_consumed:
                del pending_rules[0]
            elif current.errormessage:
                assert False, current.errormessage

            if current.consume_statement:
                del pending_observed[0]

        if pending_rules and not pending_observed:
            pending_rules[0].no_more_statements()
        elif pending_observed and not pending_rules:
            assert False, "Additional SQL statements remain"
2016-02-06 09:37:19 +00:00
2016-02-22 11:13:36 +00:00
@contextlib.contextmanager
def assert_engine(engine):
    """Context manager that records SQL executed on ``engine``.

    Yields a ``SQLAsserter``.  While the block runs, two event listeners
    group every cursor-level statement under the connection-level execution
    that produced it.  On exit the listeners are removed and the asserter is
    closed so that its ``assert_()`` method can be used.
    """
    asserter = SQLAsserter()

    # most recent (clauseelement, multiparams, params) seen by before_execute
    last_call = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(conn, clauseelement, multiparams, params):
        # grab the original statement + params before any cursor
        # execution
        last_call[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(conn, cursor, statement, parameters,
                       context, executemany):
        if not context:
            return

        # then grab real cursor statements and associate them all
        # around a single context
        accumulated = asserter.accumulated
        if accumulated and accumulated[-1].context is context:
            observed = accumulated[-1]
        else:
            clauseelement, multiparams, params = (
                last_call[0], last_call[1], last_call[2])
            observed = SQLExecuteObserved(
                context, clauseelement, multiparams, params)
            accumulated.append(observed)

        cursor_record = SQLCursorExecuteObserved(
            statement, parameters, context, executemany)
        observed.statements.append(cursor_record)

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()