update sqlalchemy
This commit is contained in:
parent
7365367c61
commit
3b436646a2
362 changed files with 37720 additions and 11021 deletions
|
|
@ -1,12 +1,12 @@
|
|||
# testing/__init__.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
from .warnings import testing_warn, assert_warnings, resetwarnings
|
||||
from .warnings import assert_warnings
|
||||
|
||||
from . import config
|
||||
|
||||
|
|
@ -19,11 +19,14 @@ def against(*queries):
|
|||
return _against(config._current, *queries)
|
||||
|
||||
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
|
||||
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
|
||||
eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
|
||||
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
|
||||
AssertsExecutionResults, expect_deprecated
|
||||
AssertsExecutionResults, expect_deprecated, expect_warnings, \
|
||||
in_, not_in_
|
||||
|
||||
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
|
||||
from .util import run_as_contextmanager, rowset, fail, \
|
||||
provide_metadata, adict, force_drop_names, \
|
||||
teardown_events
|
||||
|
||||
crashes = skip
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/assertions.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -9,79 +9,88 @@ from __future__ import absolute_import
|
|||
|
||||
from . import util as testutil
|
||||
from sqlalchemy import pool, orm, util
|
||||
from sqlalchemy.engine import default, create_engine, url
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy.engine import default, url
|
||||
from sqlalchemy.util import decorator
|
||||
from sqlalchemy import types as sqltypes, schema
|
||||
from sqlalchemy import types as sqltypes, schema, exc as sa_exc
|
||||
import warnings
|
||||
import re
|
||||
from .warnings import resetwarnings
|
||||
from .exclusions import db_spec, _is_excluded
|
||||
from . import assertsql
|
||||
from . import config
|
||||
import itertools
|
||||
from .util import fail
|
||||
import contextlib
|
||||
from . import mock
|
||||
|
||||
|
||||
def expect_warnings(*messages, **kw):
|
||||
"""Context manager which expects one or more warnings.
|
||||
|
||||
With no arguments, squelches all SAWarnings emitted via
|
||||
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
|
||||
pass string expressions that will match selected warnings via regex;
|
||||
all non-matching warnings are sent through.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
Note that the test suite sets SAWarning warnings to raise exceptions.
|
||||
|
||||
"""
|
||||
return _expect_warnings(sa_exc.SAWarning, messages, **kw)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def expect_warnings_on(db, *messages, **kw):
|
||||
"""Context manager which expects one or more warnings on specific
|
||||
dialects.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
"""
|
||||
spec = db_spec(db)
|
||||
|
||||
if isinstance(db, util.string_types) and not spec(config._current):
|
||||
yield
|
||||
else:
|
||||
with expect_warnings(*messages, **kw):
|
||||
yield
|
||||
|
||||
|
||||
def emits_warning(*messages):
|
||||
"""Mark a test as emitting a warning.
|
||||
"""Decorator form of expect_warnings().
|
||||
|
||||
Note that emits_warning does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
With no arguments, squelches all SAWarning failures. Or pass one or more
|
||||
strings; these will be matched to the root of the warning description by
|
||||
warnings.filterwarnings().
|
||||
"""
|
||||
# TODO: it would be nice to assert that a named warning was
|
||||
# emitted. should work with some monkeypatching of warnings,
|
||||
# and may work on non-CPython if they keep to the spirit of
|
||||
# warnings.showwarning's docstring.
|
||||
# - update: jython looks ok, it uses cpython's module
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
# todo: should probably be strict about this, too
|
||||
filters = [dict(action='ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)]
|
||||
if not messages:
|
||||
filters.append(dict(action='ignore',
|
||||
category=sa_exc.SAWarning))
|
||||
else:
|
||||
filters.extend(dict(action='ignore',
|
||||
message=message,
|
||||
category=sa_exc.SAWarning)
|
||||
for message in messages)
|
||||
for f in filters:
|
||||
warnings.filterwarnings(**f)
|
||||
try:
|
||||
with expect_warnings(assert_=False, *messages):
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
resetwarnings()
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def emits_warning_on(db, *warnings):
|
||||
def expect_deprecated(*messages, **kw):
|
||||
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
|
||||
|
||||
|
||||
def emits_warning_on(db, *messages):
|
||||
"""Mark a test as emitting a warning on a specific dialect.
|
||||
|
||||
With no arguments, squelches all SAWarning failures. Or pass one or more
|
||||
strings; these will be matched to the root of the warning description by
|
||||
warnings.filterwarnings().
|
||||
"""
|
||||
spec = db_spec(db)
|
||||
|
||||
Note that emits_warning_on does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
if isinstance(db, util.string_types):
|
||||
if not spec(config._current):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
wrapped = emits_warning(*warnings)(fn)
|
||||
return wrapped(*args, **kw)
|
||||
else:
|
||||
if not _is_excluded(*db):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
wrapped = emits_warning(*warnings)(fn)
|
||||
return wrapped(*args, **kw)
|
||||
with expect_warnings_on(db, assert_=False, *messages):
|
||||
return fn(*args, **kw)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
|
|
@ -95,39 +104,52 @@ def uses_deprecated(*messages):
|
|||
As a special case, you may pass a function name prefixed with //
|
||||
and it will be re-written as needed to match the standard warning
|
||||
verbiage emitted by the sqlalchemy.util.deprecated decorator.
|
||||
|
||||
Note that uses_deprecated does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with expect_deprecated(*messages):
|
||||
with expect_deprecated(*messages, assert_=False):
|
||||
return fn(*args, **kw)
|
||||
return decorate
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def expect_deprecated(*messages):
|
||||
# todo: should probably be strict about this, too
|
||||
filters = [dict(action='ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)]
|
||||
if not messages:
|
||||
filters.append(dict(action='ignore',
|
||||
category=sa_exc.SADeprecationWarning))
|
||||
else:
|
||||
filters.extend(
|
||||
[dict(action='ignore',
|
||||
message=message,
|
||||
category=sa_exc.SADeprecationWarning)
|
||||
for message in
|
||||
[(m.startswith('//') and
|
||||
('Call to deprecated function ' + m[2:]) or m)
|
||||
for m in messages]])
|
||||
def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
|
||||
|
||||
for f in filters:
|
||||
warnings.filterwarnings(**f)
|
||||
try:
|
||||
if regex:
|
||||
filters = [re.compile(msg, re.I | re.S) for msg in messages]
|
||||
else:
|
||||
filters = messages
|
||||
|
||||
seen = set(filters)
|
||||
|
||||
real_warn = warnings.warn
|
||||
|
||||
def our_warn(msg, exception, *arg, **kw):
|
||||
if not issubclass(exception, exc_cls):
|
||||
return real_warn(msg, exception, *arg, **kw)
|
||||
|
||||
if not filters:
|
||||
return
|
||||
|
||||
for filter_ in filters:
|
||||
if (regex and filter_.match(msg)) or \
|
||||
(not regex and filter_ == msg):
|
||||
seen.discard(filter_)
|
||||
break
|
||||
else:
|
||||
real_warn(msg, exception, *arg, **kw)
|
||||
|
||||
with mock.patch("warnings.warn", our_warn):
|
||||
yield
|
||||
finally:
|
||||
resetwarnings()
|
||||
|
||||
if assert_:
|
||||
assert not seen, "Warnings were not seen: %s" % \
|
||||
", ".join("%r" % (s.pattern if regex else s) for s in seen)
|
||||
|
||||
|
||||
def global_cleanup_assertions():
|
||||
|
|
@ -192,6 +214,11 @@ def ne_(a, b, msg=None):
|
|||
assert a != b, msg or "%r == %r" % (a, b)
|
||||
|
||||
|
||||
def le_(a, b, msg=None):
|
||||
"""Assert a <= b, with repr messaging on failure."""
|
||||
assert a <= b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def is_(a, b, msg=None):
|
||||
"""Assert a is b, with repr messaging on failure."""
|
||||
assert a is b, msg or "%r is not %r" % (a, b)
|
||||
|
|
@ -202,6 +229,16 @@ def is_not_(a, b, msg=None):
|
|||
assert a is not b, msg or "%r is %r" % (a, b)
|
||||
|
||||
|
||||
def in_(a, b, msg=None):
|
||||
"""Assert a in b, with repr messaging on failure."""
|
||||
assert a in b, msg or "%r not in %r" % (a, b)
|
||||
|
||||
|
||||
def not_in_(a, b, msg=None):
|
||||
"""Assert a in not b, with repr messaging on failure."""
|
||||
assert a not in b, msg or "%r is in %r" % (a, b)
|
||||
|
||||
|
||||
def startswith_(a, fragment, msg=None):
|
||||
"""Assert a.startswith(fragment), with repr messaging on failure."""
|
||||
assert a.startswith(fragment), msg or "%r does not start with %r" % (
|
||||
|
|
@ -233,6 +270,7 @@ class AssertsCompiledSQL(object):
|
|||
def assert_compile(self, clause, result, params=None,
|
||||
checkparams=None, dialect=None,
|
||||
checkpositional=None,
|
||||
check_prefetch=None,
|
||||
use_default_dialect=False,
|
||||
allow_dialect_select=False,
|
||||
literal_binds=False):
|
||||
|
|
@ -293,6 +331,8 @@ class AssertsCompiledSQL(object):
|
|||
if checkpositional is not None:
|
||||
p = c.construct_params(params)
|
||||
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
|
||||
if check_prefetch is not None:
|
||||
eq_(c.prefetch, check_prefetch)
|
||||
|
||||
|
||||
class ComparesTables(object):
|
||||
|
|
@ -409,29 +449,27 @@ class AssertsExecutionResults(object):
|
|||
cls.__name__, repr(expected_item)))
|
||||
return True
|
||||
|
||||
def assert_sql_execution(self, db, callable_, *rules):
|
||||
assertsql.asserter.add_rules(rules)
|
||||
try:
|
||||
callable_()
|
||||
assertsql.asserter.statement_complete()
|
||||
finally:
|
||||
assertsql.asserter.clear_rules()
|
||||
def sql_execution_asserter(self, db=None):
|
||||
if db is None:
|
||||
from . import db as db
|
||||
|
||||
def assert_sql(self, db, callable_, list_, with_sequences=None):
|
||||
if (with_sequences is not None and
|
||||
config.db.dialect.supports_sequences):
|
||||
rules = with_sequences
|
||||
else:
|
||||
rules = list_
|
||||
return assertsql.assert_engine(db)
|
||||
|
||||
def assert_sql_execution(self, db, callable_, *rules):
|
||||
with self.sql_execution_asserter(db) as asserter:
|
||||
callable_()
|
||||
asserter.assert_(*rules)
|
||||
|
||||
def assert_sql(self, db, callable_, rules):
|
||||
|
||||
newrules = []
|
||||
for rule in rules:
|
||||
if isinstance(rule, dict):
|
||||
newrule = assertsql.AllOf(*[
|
||||
assertsql.ExactSQL(k, v) for k, v in rule.items()
|
||||
assertsql.CompiledSQL(k, v) for k, v in rule.items()
|
||||
])
|
||||
else:
|
||||
newrule = assertsql.ExactSQL(*rule)
|
||||
newrule = assertsql.CompiledSQL(*rule)
|
||||
newrules.append(newrule)
|
||||
|
||||
self.assert_sql_execution(db, callable_, *newrules)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/assertsql.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -8,84 +8,151 @@
|
|||
from ..engine.default import DefaultDialect
|
||||
from .. import util
|
||||
import re
|
||||
import collections
|
||||
import contextlib
|
||||
from .. import event
|
||||
from sqlalchemy.schema import _DDLCompiles
|
||||
from sqlalchemy.engine.util import _distill_params
|
||||
from sqlalchemy.engine import url
|
||||
|
||||
|
||||
class AssertRule(object):
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
is_consumed = False
|
||||
errormessage = None
|
||||
consume_statement = True
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
pass
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
pass
|
||||
|
||||
def is_consumed(self):
|
||||
"""Return True if this rule has been consumed, False if not.
|
||||
|
||||
Should raise an AssertionError if this rule's condition has
|
||||
definitely failed.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def rule_passed(self):
|
||||
"""Return True if the last test of this rule passed, False if
|
||||
failed, None if no test was applied."""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def consume_final(self):
|
||||
"""Return True if this rule has been consumed.
|
||||
|
||||
Should raise an AssertionError if this rule's condition has not
|
||||
been consumed or has failed.
|
||||
|
||||
"""
|
||||
|
||||
if self._result is None:
|
||||
assert False, 'Rule has not been consumed'
|
||||
return self.is_consumed()
|
||||
def no_more_statements(self):
|
||||
assert False, 'All statements are complete, but pending '\
|
||||
'assertion rules remain'
|
||||
|
||||
|
||||
class SQLMatchRule(AssertRule):
|
||||
def __init__(self):
|
||||
self._result = None
|
||||
self._errmsg = ""
|
||||
|
||||
def rule_passed(self):
|
||||
return self._result
|
||||
|
||||
def is_consumed(self):
|
||||
if self._result is None:
|
||||
return False
|
||||
|
||||
assert self._result, self._errmsg
|
||||
|
||||
return True
|
||||
pass
|
||||
|
||||
|
||||
class ExactSQL(SQLMatchRule):
|
||||
class CursorSQL(SQLMatchRule):
|
||||
consume_statement = False
|
||||
|
||||
def __init__(self, sql, params=None):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.sql = sql
|
||||
def __init__(self, statement, params=None):
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
if not context:
|
||||
return
|
||||
_received_statement = \
|
||||
_process_engine_statement(context.unicode_statement,
|
||||
context)
|
||||
_received_parameters = context.compiled_parameters
|
||||
def process_statement(self, execute_observed):
|
||||
stmt = execute_observed.statements[0]
|
||||
if self.statement != stmt.statement or (
|
||||
self.params is not None and self.params != stmt.parameters):
|
||||
self.errormessage = \
|
||||
"Testing for exact SQL %s parameters %s received %s %s" % (
|
||||
self.statement, self.params,
|
||||
stmt.statement, stmt.parameters
|
||||
)
|
||||
else:
|
||||
execute_observed.statements.pop(0)
|
||||
self.is_consumed = True
|
||||
if not execute_observed.statements:
|
||||
self.consume_statement = True
|
||||
|
||||
# TODO: remove this step once all unit tests are migrated, as
|
||||
# ExactSQL should really be *exact* SQL
|
||||
|
||||
sql = _process_assertion_statement(self.sql, context)
|
||||
equivalent = _received_statement == sql
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
|
||||
def __init__(self, statement, params=None, dialect='default'):
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
self.dialect = dialect
|
||||
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
stmt = re.sub(r'[\n\t]', '', self.statement)
|
||||
return received_statement == stmt
|
||||
|
||||
def _compile_dialect(self, execute_observed):
|
||||
if self.dialect == 'default':
|
||||
return DefaultDialect()
|
||||
else:
|
||||
# ugh
|
||||
if self.dialect == 'postgresql':
|
||||
params = {'implicit_returning': True}
|
||||
else:
|
||||
params = {}
|
||||
return url.URL(self.dialect).get_dialect()(**params)
|
||||
|
||||
def _received_statement(self, execute_observed):
|
||||
"""reconstruct the statement and params in terms
|
||||
of a target dialect, which for CompiledSQL is just DefaultDialect."""
|
||||
|
||||
context = execute_observed.context
|
||||
compare_dialect = self._compile_dialect(execute_observed)
|
||||
if isinstance(context.compiled.statement, _DDLCompiles):
|
||||
compiled = \
|
||||
context.compiled.statement.compile(dialect=compare_dialect)
|
||||
else:
|
||||
compiled = (
|
||||
context.compiled.statement.compile(
|
||||
dialect=compare_dialect,
|
||||
column_keys=context.compiled.column_keys,
|
||||
inline=context.compiled.inline)
|
||||
)
|
||||
_received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
|
||||
parameters = execute_observed.parameters
|
||||
|
||||
if not parameters:
|
||||
_received_parameters = [compiled.construct_params()]
|
||||
else:
|
||||
_received_parameters = [
|
||||
compiled.construct_params(m) for m in parameters]
|
||||
|
||||
return _received_statement, _received_parameters
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
context = execute_observed.context
|
||||
|
||||
_received_statement, _received_parameters = \
|
||||
self._received_statement(execute_observed)
|
||||
params = self._all_params(context)
|
||||
|
||||
equivalent = self._compare_sql(execute_observed, _received_statement)
|
||||
|
||||
if equivalent:
|
||||
if params is not None:
|
||||
all_params = list(params)
|
||||
all_received = list(_received_parameters)
|
||||
while all_params and all_received:
|
||||
param = dict(all_params.pop(0))
|
||||
|
||||
for idx, received in enumerate(list(all_received)):
|
||||
# do a positive compare only
|
||||
for param_key in param:
|
||||
# a key in param did not match current
|
||||
# 'received'
|
||||
if param_key not in received or \
|
||||
received[param_key] != param[param_key]:
|
||||
break
|
||||
else:
|
||||
# all keys in param matched 'received';
|
||||
# onto next param
|
||||
del all_received[idx]
|
||||
break
|
||||
else:
|
||||
# param did not match any entry
|
||||
# in all_received
|
||||
equivalent = False
|
||||
break
|
||||
if all_params or all_received:
|
||||
equivalent = False
|
||||
|
||||
if equivalent:
|
||||
self.is_consumed = True
|
||||
self.errormessage = None
|
||||
else:
|
||||
self.errormessage = self._failure_message(params) % {
|
||||
'received_statement': _received_statement,
|
||||
'received_parameters': _received_parameters
|
||||
}
|
||||
|
||||
def _all_params(self, context):
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
|
|
@ -93,127 +160,84 @@ class ExactSQL(SQLMatchRule):
|
|||
params = self.params
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
equivalent = equivalent and params \
|
||||
== context.compiled_parameters
|
||||
return params
|
||||
else:
|
||||
params = {}
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
self._errmsg = (
|
||||
'Testing for exact statement %r exact params %r, '
|
||||
'received %r with params %r' %
|
||||
(sql, params, _received_statement, _received_parameters))
|
||||
return None
|
||||
|
||||
def _failure_message(self, expected_params):
|
||||
return (
|
||||
'Testing for compiled statement %r partial params %r, '
|
||||
'received %%(received_statement)r with params '
|
||||
'%%(received_parameters)r' % (
|
||||
self.statement.replace('%', '%%'), expected_params
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RegexSQL(SQLMatchRule):
|
||||
|
||||
class RegexSQL(CompiledSQL):
|
||||
def __init__(self, regex, params=None):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.regex = re.compile(regex)
|
||||
self.orig_regex = regex
|
||||
self.params = params
|
||||
self.dialect = 'default'
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
if not context:
|
||||
return
|
||||
_received_statement = \
|
||||
_process_engine_statement(context.unicode_statement,
|
||||
context)
|
||||
_received_parameters = context.compiled_parameters
|
||||
equivalent = bool(self.regex.match(_received_statement))
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
|
||||
# do a positive compare only
|
||||
|
||||
for param, received in zip(params, _received_parameters):
|
||||
for k, v in param.items():
|
||||
if k not in received or received[k] != v:
|
||||
equivalent = False
|
||||
break
|
||||
else:
|
||||
params = {}
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
self._errmsg = \
|
||||
'Testing for regex %r partial params %r, received %r '\
|
||||
'with params %r' % (self.orig_regex, params,
|
||||
_received_statement,
|
||||
_received_parameters)
|
||||
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
|
||||
def __init__(self, statement, params=None):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
if not context:
|
||||
return
|
||||
from sqlalchemy.schema import _DDLCompiles
|
||||
_received_parameters = list(context.compiled_parameters)
|
||||
|
||||
# recompile from the context, using the default dialect
|
||||
|
||||
if isinstance(context.compiled.statement, _DDLCompiles):
|
||||
compiled = \
|
||||
context.compiled.statement.compile(dialect=DefaultDialect())
|
||||
else:
|
||||
compiled = (
|
||||
context.compiled.statement.compile(
|
||||
dialect=DefaultDialect(),
|
||||
column_keys=context.compiled.column_keys)
|
||||
def _failure_message(self, expected_params):
|
||||
return (
|
||||
'Testing for compiled statement ~%r partial params %r, '
|
||||
'received %%(received_statement)r with params '
|
||||
'%%(received_parameters)r' % (
|
||||
self.orig_regex, expected_params
|
||||
)
|
||||
_received_statement = re.sub(r'[\n\t]', '', str(compiled))
|
||||
equivalent = self.statement == _received_statement
|
||||
if self.params:
|
||||
if util.callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
else:
|
||||
params = list(params)
|
||||
all_params = list(params)
|
||||
all_received = list(_received_parameters)
|
||||
while params:
|
||||
param = dict(params.pop(0))
|
||||
for k, v in context.compiled.params.items():
|
||||
param.setdefault(k, v)
|
||||
if param not in _received_parameters:
|
||||
equivalent = False
|
||||
break
|
||||
else:
|
||||
_received_parameters.remove(param)
|
||||
if _received_parameters:
|
||||
equivalent = False
|
||||
else:
|
||||
params = {}
|
||||
all_params = {}
|
||||
all_received = []
|
||||
self._result = equivalent
|
||||
if not self._result:
|
||||
print('Testing for compiled statement %r partial params '
|
||||
'%r, received %r with params %r' %
|
||||
(self.statement, all_params,
|
||||
_received_statement, all_received))
|
||||
self._errmsg = (
|
||||
'Testing for compiled statement %r partial params %r, '
|
||||
'received %r with params %r' %
|
||||
(self.statement, all_params,
|
||||
_received_statement, all_received))
|
||||
)
|
||||
|
||||
# print self._errmsg
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
return bool(self.regex.match(received_statement))
|
||||
|
||||
|
||||
class DialectSQL(CompiledSQL):
|
||||
def _compile_dialect(self, execute_observed):
|
||||
return execute_observed.context.dialect
|
||||
|
||||
def _compare_no_space(self, real_stmt, received_stmt):
|
||||
stmt = re.sub(r'[\n\t]', '', real_stmt)
|
||||
return received_stmt == stmt
|
||||
|
||||
def _received_statement(self, execute_observed):
|
||||
received_stmt, received_params = super(DialectSQL, self).\
|
||||
_received_statement(execute_observed)
|
||||
|
||||
# TODO: why do we need this part?
|
||||
for real_stmt in execute_observed.statements:
|
||||
if self._compare_no_space(real_stmt.statement, received_stmt):
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Can't locate compiled statement %r in list of "
|
||||
"statements actually invoked" % received_stmt)
|
||||
|
||||
return received_stmt, execute_observed.context.compiled_parameters
|
||||
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
stmt = re.sub(r'[\n\t]', '', self.statement)
|
||||
# convert our comparison statement to have the
|
||||
# paramstyle of the received
|
||||
paramstyle = execute_observed.context.dialect.paramstyle
|
||||
if paramstyle == 'pyformat':
|
||||
stmt = re.sub(
|
||||
r':([\w_]+)', r"%(\1)s", stmt)
|
||||
else:
|
||||
# positional params
|
||||
repl = None
|
||||
if paramstyle == 'qmark':
|
||||
repl = "?"
|
||||
elif paramstyle == 'format':
|
||||
repl = r"%s"
|
||||
elif paramstyle == 'numeric':
|
||||
repl = None
|
||||
stmt = re.sub(r':([\w_]+)', repl, stmt)
|
||||
|
||||
return received_statement == stmt
|
||||
|
||||
|
||||
class CountStatements(AssertRule):
|
||||
|
|
@ -222,21 +246,13 @@ class CountStatements(AssertRule):
|
|||
self.count = count
|
||||
self._statement_count = 0
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
def process_statement(self, execute_observed):
|
||||
self._statement_count += 1
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
pass
|
||||
|
||||
def is_consumed(self):
|
||||
return False
|
||||
|
||||
def consume_final(self):
|
||||
assert self.count == self._statement_count, \
|
||||
'desired statement count %d does not match %d' \
|
||||
% (self.count, self._statement_count)
|
||||
return True
|
||||
def no_more_statements(self):
|
||||
if self.count != self._statement_count:
|
||||
assert False, 'desired statement count %d does not match %d' \
|
||||
% (self.count, self._statement_count)
|
||||
|
||||
|
||||
class AllOf(AssertRule):
|
||||
|
|
@ -244,116 +260,113 @@ class AllOf(AssertRule):
|
|||
def __init__(self, *rules):
|
||||
self.rules = set(rules)
|
||||
|
||||
def process_execute(self, clauseelement, *multiparams, **params):
|
||||
for rule in self.rules:
|
||||
rule.process_execute(clauseelement, *multiparams, **params)
|
||||
|
||||
def process_cursor_execute(self, statement, parameters, context,
|
||||
executemany):
|
||||
for rule in self.rules:
|
||||
rule.process_cursor_execute(statement, parameters, context,
|
||||
executemany)
|
||||
|
||||
def is_consumed(self):
|
||||
if not self.rules:
|
||||
return True
|
||||
def process_statement(self, execute_observed):
|
||||
for rule in list(self.rules):
|
||||
if rule.rule_passed(): # a rule passed, move on
|
||||
self.rules.remove(rule)
|
||||
return len(self.rules) == 0
|
||||
return False
|
||||
|
||||
def rule_passed(self):
|
||||
return self.is_consumed()
|
||||
|
||||
def consume_final(self):
|
||||
return len(self.rules) == 0
|
||||
rule.errormessage = None
|
||||
rule.process_statement(execute_observed)
|
||||
if rule.is_consumed:
|
||||
self.rules.discard(rule)
|
||||
if not self.rules:
|
||||
self.is_consumed = True
|
||||
break
|
||||
elif not rule.errormessage:
|
||||
# rule is not done yet
|
||||
self.errormessage = None
|
||||
break
|
||||
else:
|
||||
self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
|
||||
class Or(AllOf):
|
||||
def __init__(self, *rules):
|
||||
self.rules = set(rules)
|
||||
self._consume_final = False
|
||||
|
||||
def is_consumed(self):
|
||||
if not self.rules:
|
||||
return True
|
||||
for rule in list(self.rules):
|
||||
if rule.rule_passed(): # a rule passed
|
||||
self._consume_final = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def consume_final(self):
|
||||
assert self._consume_final, "Unsatisified rules remain"
|
||||
|
||||
|
||||
def _process_engine_statement(query, context):
|
||||
if util.jython:
|
||||
|
||||
# oracle+zxjdbc passes a PyStatement when returning into
|
||||
|
||||
query = str(query)
|
||||
if context.engine.name == 'mssql' \
|
||||
and query.endswith('; select scope_identity()'):
|
||||
query = query[:-25]
|
||||
query = re.sub(r'\n', '', query)
|
||||
return query
|
||||
|
||||
|
||||
def _process_assertion_statement(query, context):
|
||||
paramstyle = context.dialect.paramstyle
|
||||
if paramstyle == 'named':
|
||||
pass
|
||||
elif paramstyle == 'pyformat':
|
||||
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
|
||||
else:
|
||||
# positional params
|
||||
repl = None
|
||||
if paramstyle == 'qmark':
|
||||
repl = "?"
|
||||
elif paramstyle == 'format':
|
||||
repl = r"%s"
|
||||
elif paramstyle == 'numeric':
|
||||
repl = None
|
||||
query = re.sub(r':([\w_]+)', repl, query)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class SQLAssert(object):
|
||||
|
||||
rules = None
|
||||
|
||||
def add_rules(self, rules):
|
||||
self.rules = list(rules)
|
||||
|
||||
def statement_complete(self):
|
||||
def process_statement(self, execute_observed):
|
||||
for rule in self.rules:
|
||||
if not rule.consume_final():
|
||||
assert False, \
|
||||
'All statements are complete, but pending '\
|
||||
'assertion rules remain'
|
||||
rule.process_statement(execute_observed)
|
||||
if rule.is_consumed:
|
||||
self.is_consumed = True
|
||||
break
|
||||
else:
|
||||
self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
def clear_rules(self):
|
||||
del self.rules
|
||||
|
||||
def execute(self, conn, clauseelement, multiparams, params, result):
|
||||
if self.rules is not None:
|
||||
if not self.rules:
|
||||
assert False, \
|
||||
'All rules have been exhausted, but further '\
|
||||
'statements remain'
|
||||
rule = self.rules[0]
|
||||
rule.process_execute(clauseelement, *multiparams, **params)
|
||||
if rule.is_consumed():
|
||||
self.rules.pop(0)
|
||||
class SQLExecuteObserved(object):
|
||||
def __init__(self, context, clauseelement, multiparams, params):
|
||||
self.context = context
|
||||
self.clauseelement = clauseelement
|
||||
self.parameters = _distill_params(multiparams, params)
|
||||
self.statements = []
|
||||
|
||||
def cursor_execute(self, conn, cursor, statement, parameters,
|
||||
|
||||
class SQLCursorExecuteObserved(
|
||||
collections.namedtuple(
|
||||
"SQLCursorExecuteObserved",
|
||||
["statement", "parameters", "context", "executemany"])
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAsserter(object):
|
||||
def __init__(self):
|
||||
self.accumulated = []
|
||||
|
||||
def _close(self):
|
||||
self._final = self.accumulated
|
||||
del self.accumulated
|
||||
|
||||
def assert_(self, *rules):
|
||||
rules = list(rules)
|
||||
observed = list(self._final)
|
||||
|
||||
while observed and rules:
|
||||
rule = rules[0]
|
||||
rule.process_statement(observed[0])
|
||||
if rule.is_consumed:
|
||||
rules.pop(0)
|
||||
elif rule.errormessage:
|
||||
assert False, rule.errormessage
|
||||
|
||||
if rule.consume_statement:
|
||||
observed.pop(0)
|
||||
|
||||
if not observed and rules:
|
||||
rules[0].no_more_statements()
|
||||
elif not rules and observed:
|
||||
assert False, "Additional SQL statements remain"
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_engine(engine):
|
||||
asserter = SQLAsserter()
|
||||
|
||||
orig = []
|
||||
|
||||
@event.listens_for(engine, "before_execute")
|
||||
def connection_execute(conn, clauseelement, multiparams, params):
|
||||
# grab the original statement + params before any cursor
|
||||
# execution
|
||||
orig[:] = clauseelement, multiparams, params
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def cursor_execute(conn, cursor, statement, parameters,
|
||||
context, executemany):
|
||||
if self.rules:
|
||||
rule = self.rules[0]
|
||||
rule.process_cursor_execute(statement, parameters, context,
|
||||
executemany)
|
||||
if not context:
|
||||
return
|
||||
# then grab real cursor statements and associate them all
|
||||
# around a single context
|
||||
if asserter.accumulated and \
|
||||
asserter.accumulated[-1].context is context:
|
||||
obs = asserter.accumulated[-1]
|
||||
else:
|
||||
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
|
||||
asserter.accumulated.append(obs)
|
||||
obs.statements.append(
|
||||
SQLCursorExecuteObserved(
|
||||
statement, parameters, context, executemany)
|
||||
)
|
||||
|
||||
asserter = SQLAssert()
|
||||
try:
|
||||
yield asserter
|
||||
finally:
|
||||
event.remove(engine, "after_cursor_execute", cursor_execute)
|
||||
event.remove(engine, "before_execute", connection_execute)
|
||||
asserter._close()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/config.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -12,8 +12,10 @@ db = None
|
|||
db_url = None
|
||||
db_opts = None
|
||||
file_config = None
|
||||
|
||||
test_schema = None
|
||||
test_schema_2 = None
|
||||
_current = None
|
||||
_skip_test_exception = None
|
||||
|
||||
|
||||
class Config(object):
|
||||
|
|
@ -22,12 +24,14 @@ class Config(object):
|
|||
self.db_opts = db_opts
|
||||
self.options = options
|
||||
self.file_config = file_config
|
||||
self.test_schema = "test_schema"
|
||||
self.test_schema_2 = "test_schema_2"
|
||||
|
||||
_stack = collections.deque()
|
||||
_configs = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, db, db_opts, options, file_config, namespace):
|
||||
def register(cls, db, db_opts, options, file_config):
|
||||
"""add a config as one of the global configs.
|
||||
|
||||
If there are no configs set up yet, this config also
|
||||
|
|
@ -35,18 +39,19 @@ class Config(object):
|
|||
"""
|
||||
cfg = Config(db, db_opts, options, file_config)
|
||||
|
||||
global _current
|
||||
if not _current:
|
||||
cls.set_as_current(cfg, namespace)
|
||||
cls._configs[cfg.db.name] = cfg
|
||||
cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
|
||||
cls._configs[cfg.db] = cfg
|
||||
return cfg
|
||||
|
||||
@classmethod
|
||||
def set_as_current(cls, config, namespace):
|
||||
global db, _current, db_url
|
||||
global db, _current, db_url, test_schema, test_schema_2, db_opts
|
||||
_current = config
|
||||
db_url = config.db.url
|
||||
db_opts = config.db_opts
|
||||
test_schema = config.test_schema
|
||||
test_schema_2 = config.test_schema_2
|
||||
namespace.db = db = config.db
|
||||
|
||||
@classmethod
|
||||
|
|
@ -78,3 +83,10 @@ class Config(object):
|
|||
def all_dbs(cls):
|
||||
for cfg in cls.all_configs():
|
||||
yield cfg.db
|
||||
|
||||
def skip_test(self, msg):
|
||||
skip_test(msg)
|
||||
|
||||
|
||||
def skip_test(msg):
|
||||
raise _skip_test_exception(msg)
|
||||
|
|
|
|||
|
|
@ -8,4 +8,4 @@ import pytest
|
|||
|
||||
class TestSuite(unittest.TestCase):
|
||||
def test_sqlalchemy(self):
|
||||
pytest.main()
|
||||
pytest.main(["-n", "4", "-q"])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/engines.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -7,15 +7,12 @@
|
|||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import types
|
||||
import weakref
|
||||
from collections import deque
|
||||
from . import config
|
||||
from .util import decorator
|
||||
from .. import event, pool
|
||||
import re
|
||||
import warnings
|
||||
from .. import util
|
||||
|
||||
|
||||
class ConnectionKiller(object):
|
||||
|
|
@ -40,8 +37,6 @@ class ConnectionKiller(object):
|
|||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"testing_reaper couldn't "
|
||||
|
|
@ -103,7 +98,14 @@ def drop_all_tables(metadata, bind):
|
|||
testing_reaper.close_all()
|
||||
if hasattr(bind, 'close'):
|
||||
bind.close()
|
||||
metadata.drop_all(bind)
|
||||
|
||||
if not config.db.dialect.supports_alter:
|
||||
from . import assertions
|
||||
with assertions.expect_warnings(
|
||||
"Can't sort tables", assert_=False):
|
||||
metadata.drop_all(bind)
|
||||
else:
|
||||
metadata.drop_all(bind)
|
||||
|
||||
|
||||
@decorator
|
||||
|
|
@ -171,8 +173,6 @@ class ReconnectFixture(object):
|
|||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"ReconnectFixture couldn't "
|
||||
|
|
@ -211,7 +211,7 @@ def testing_engine(url=None, options=None):
|
|||
"""Produce an engine configured by --options with optional overrides."""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from .assertsql import asserter
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
if not options:
|
||||
use_reaper = True
|
||||
|
|
@ -219,15 +219,20 @@ def testing_engine(url=None, options=None):
|
|||
use_reaper = options.pop('use_reaper', True)
|
||||
|
||||
url = url or config.db.url
|
||||
|
||||
url = make_url(url)
|
||||
if options is None:
|
||||
options = config.db_opts
|
||||
if config.db is None or url.drivername == config.db.url.drivername:
|
||||
options = config.db_opts
|
||||
else:
|
||||
options = {}
|
||||
|
||||
engine = create_engine(url, **options)
|
||||
engine._has_events = True # enable event blocks, helps with profiling
|
||||
|
||||
if isinstance(engine.pool, pool.QueuePool):
|
||||
engine.pool._timeout = 0
|
||||
engine.pool._max_overflow = 0
|
||||
event.listen(engine, 'after_execute', asserter.execute)
|
||||
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
|
||||
if use_reaper:
|
||||
event.listen(engine.pool, 'connect', testing_reaper.connect)
|
||||
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
|
||||
|
|
@ -287,10 +292,10 @@ class DBAPIProxyCursor(object):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, engine, conn):
|
||||
def __init__(self, engine, conn, *args, **kwargs):
|
||||
self.engine = engine
|
||||
self.connection = conn
|
||||
self.cursor = conn.cursor()
|
||||
self.cursor = conn.cursor(*args, **kwargs)
|
||||
|
||||
def execute(self, stmt, parameters=None, **kw):
|
||||
if parameters:
|
||||
|
|
@ -318,8 +323,8 @@ class DBAPIProxyConnection(object):
|
|||
self.engine = engine
|
||||
self.cursor_cls = cursor_cls
|
||||
|
||||
def cursor(self):
|
||||
return self.cursor_cls(self.engine, self.conn)
|
||||
def cursor(self, *args, **kwargs):
|
||||
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
|
@ -339,112 +344,3 @@ def proxying_engine(conn_cls=DBAPIProxyConnection,
|
|||
return testing_engine(options={'creator': mock_conn})
|
||||
|
||||
|
||||
class ReplayableSession(object):
|
||||
"""A simple record/playback tool.
|
||||
|
||||
This is *not* a mock testing class. It only records a session for later
|
||||
playback and makes no assertions on call consistency whatsoever. It's
|
||||
unlikely to be suitable for anything other than DB-API recording.
|
||||
|
||||
"""
|
||||
|
||||
Callable = object()
|
||||
NoAttribute = object()
|
||||
|
||||
if util.py2k:
|
||||
Natives = set([getattr(types, t)
|
||||
for t in dir(types) if not t.startswith('_')]).\
|
||||
difference([getattr(types, t)
|
||||
for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
'MethodType', 'BuiltinMethodType',
|
||||
'LambdaType', 'UnboundMethodType',)])
|
||||
else:
|
||||
Natives = set([getattr(types, t)
|
||||
for t in dir(types) if not t.startswith('_')]).\
|
||||
union([type(t) if not isinstance(t, type)
|
||||
else t for t in __builtins__.values()]).\
|
||||
difference([getattr(types, t)
|
||||
for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
'MethodType', 'BuiltinMethodType',
|
||||
'LambdaType', )])
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = deque()
|
||||
|
||||
def recorder(self, base):
|
||||
return self.Recorder(self.buffer, base)
|
||||
|
||||
def player(self):
|
||||
return self.Player(self.buffer)
|
||||
|
||||
class Recorder(object):
|
||||
def __init__(self, buffer, subject):
|
||||
self._buffer = buffer
|
||||
self._subject = subject
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
|
||||
result = subject(*args, **kw)
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return self._subject
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
try:
|
||||
result = type(subject).__getattribute__(subject, key)
|
||||
except AttributeError:
|
||||
buffer.append(ReplayableSession.NoAttribute)
|
||||
raise
|
||||
else:
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
class Player(object):
|
||||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return None
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
elif result is ReplayableSession.NoAttribute:
|
||||
raise AttributeError(key)
|
||||
else:
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/entities.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/exclusions.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -7,87 +7,161 @@
|
|||
|
||||
|
||||
import operator
|
||||
from .plugin.plugin_base import SkipTest
|
||||
from ..util import decorator
|
||||
from . import config
|
||||
from .. import util
|
||||
import contextlib
|
||||
import inspect
|
||||
import contextlib
|
||||
from sqlalchemy.util.compat import inspect_getargspec
|
||||
|
||||
|
||||
class skip_if(object):
|
||||
def __init__(self, predicate, reason=None):
|
||||
self.predicate = _as_predicate(predicate)
|
||||
self.reason = reason
|
||||
def skip_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.skips.add(pred)
|
||||
return rule
|
||||
|
||||
_fails_on = None
|
||||
|
||||
def fails_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.fails.add(pred)
|
||||
return rule
|
||||
|
||||
|
||||
class compound(object):
|
||||
def __init__(self):
|
||||
self.fails = set()
|
||||
self.skips = set()
|
||||
self.tags = set()
|
||||
|
||||
def __add__(self, other):
|
||||
def decorate(fn):
|
||||
return other(self(fn))
|
||||
return decorate
|
||||
return self.add(other)
|
||||
|
||||
def add(self, *others):
|
||||
copy = compound()
|
||||
copy.fails.update(self.fails)
|
||||
copy.skips.update(self.skips)
|
||||
copy.tags.update(self.tags)
|
||||
for other in others:
|
||||
copy.fails.update(other.fails)
|
||||
copy.skips.update(other.skips)
|
||||
copy.tags.update(other.tags)
|
||||
return copy
|
||||
|
||||
def not_(self):
|
||||
copy = compound()
|
||||
copy.fails.update(NotPredicate(fail) for fail in self.fails)
|
||||
copy.skips.update(NotPredicate(skip) for skip in self.skips)
|
||||
copy.tags.update(self.tags)
|
||||
return copy
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.enabled_for_config(config._current)
|
||||
|
||||
def enabled_for_config(self, config):
|
||||
return not self.predicate(config)
|
||||
for predicate in self.skips.union(self.fails):
|
||||
if predicate(config):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def matching_config_reasons(self, config):
|
||||
return [
|
||||
predicate._as_string(config) for predicate
|
||||
in self.skips.union(self.fails)
|
||||
if predicate(config)
|
||||
]
|
||||
|
||||
def include_test(self, include_tags, exclude_tags):
|
||||
return bool(
|
||||
not self.tags.intersection(exclude_tags) and
|
||||
(not include_tags or self.tags.intersection(include_tags))
|
||||
)
|
||||
|
||||
def _extend(self, other):
|
||||
self.skips.update(other.skips)
|
||||
self.fails.update(other.fails)
|
||||
self.tags.update(other.tags)
|
||||
|
||||
def __call__(self, fn):
|
||||
if hasattr(fn, '_sa_exclusion_extend'):
|
||||
fn._sa_exclusion_extend._extend(self)
|
||||
return fn
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
return self._do(config._current, fn, *args, **kw)
|
||||
decorated = decorate(fn)
|
||||
decorated._sa_exclusion_extend = self
|
||||
return decorated
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fail_if(self, name='block'):
|
||||
def fail_if(self):
|
||||
all_fails = compound()
|
||||
all_fails.fails.update(self.skips.union(self.fails))
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception as ex:
|
||||
if self.predicate(config._current):
|
||||
print(("%s failed as expected (%s): %s " % (
|
||||
name, self.predicate, str(ex))))
|
||||
else:
|
||||
raise
|
||||
all_fails._expect_failure(config._current, ex)
|
||||
else:
|
||||
if self.predicate(config._current):
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' (%s)" %
|
||||
(name, self.predicate))
|
||||
all_fails._expect_success(config._current)
|
||||
|
||||
def __call__(self, fn):
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
if self.predicate(config._current):
|
||||
if self.reason:
|
||||
msg = "'%s' : %s" % (
|
||||
fn.__name__,
|
||||
self.reason
|
||||
def _do(self, config, fn, *args, **kw):
|
||||
for skip in self.skips:
|
||||
if skip(config):
|
||||
msg = "'%s' : %s" % (
|
||||
fn.__name__,
|
||||
skip._as_string(config)
|
||||
)
|
||||
config.skip_test(msg)
|
||||
|
||||
try:
|
||||
return_value = fn(*args, **kw)
|
||||
except Exception as ex:
|
||||
self._expect_failure(config, ex, name=fn.__name__)
|
||||
else:
|
||||
self._expect_success(config, name=fn.__name__)
|
||||
return return_value
|
||||
|
||||
def _expect_failure(self, config, ex, name='block'):
|
||||
for fail in self.fails:
|
||||
if fail(config):
|
||||
print(("%s failed as expected (%s): %s " % (
|
||||
name, fail._as_string(config), str(ex))))
|
||||
break
|
||||
else:
|
||||
util.raise_from_cause(ex)
|
||||
|
||||
def _expect_success(self, config, name='block'):
|
||||
if not self.fails:
|
||||
return
|
||||
for fail in self.fails:
|
||||
if not fail(config):
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' (%s)" %
|
||||
(
|
||||
name,
|
||||
" and ".join(
|
||||
fail._as_string(config)
|
||||
for fail in self.fails
|
||||
)
|
||||
else:
|
||||
msg = "'%s': %s" % (
|
||||
fn.__name__, self.predicate
|
||||
)
|
||||
raise SkipTest(msg)
|
||||
else:
|
||||
if self._fails_on:
|
||||
with self._fails_on.fail_if(name=fn.__name__):
|
||||
return fn(*args, **kw)
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return decorate(fn)
|
||||
|
||||
def fails_on(self, other, reason=None):
|
||||
self._fails_on = skip_if(other, reason)
|
||||
return self
|
||||
|
||||
def fails_on_everything_except(self, *dbs):
|
||||
self._fails_on = skip_if(fails_on_everything_except(*dbs))
|
||||
return self
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class fails_if(skip_if):
|
||||
def __call__(self, fn):
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with self.fail_if(name=fn.__name__):
|
||||
return fn(*args, **kw)
|
||||
return decorate(fn)
|
||||
def requires_tag(tagname):
|
||||
return tags([tagname])
|
||||
|
||||
|
||||
def tags(tagnames):
|
||||
comp = compound()
|
||||
comp.tags.update(tagnames)
|
||||
return comp
|
||||
|
||||
|
||||
def only_if(predicate, reason=None):
|
||||
|
|
@ -102,13 +176,17 @@ def succeeds_if(predicate, reason=None):
|
|||
|
||||
class Predicate(object):
|
||||
@classmethod
|
||||
def as_predicate(cls, predicate):
|
||||
if isinstance(predicate, skip_if):
|
||||
return NotPredicate(predicate.predicate)
|
||||
def as_predicate(cls, predicate, description=None):
|
||||
if isinstance(predicate, compound):
|
||||
return cls.as_predicate(predicate.enabled_for_config, description)
|
||||
elif isinstance(predicate, Predicate):
|
||||
if description and predicate.description is None:
|
||||
predicate.description = description
|
||||
return predicate
|
||||
elif isinstance(predicate, list):
|
||||
return OrPredicate([cls.as_predicate(pred) for pred in predicate])
|
||||
elif isinstance(predicate, (list, set)):
|
||||
return OrPredicate(
|
||||
[cls.as_predicate(pred) for pred in predicate],
|
||||
description)
|
||||
elif isinstance(predicate, tuple):
|
||||
return SpecPredicate(*predicate)
|
||||
elif isinstance(predicate, util.string_types):
|
||||
|
|
@ -119,12 +197,26 @@ class Predicate(object):
|
|||
op = tokens.pop(0)
|
||||
if tokens:
|
||||
spec = tuple(int(d) for d in tokens.pop(0).split("."))
|
||||
return SpecPredicate(db, op, spec)
|
||||
return SpecPredicate(db, op, spec, description=description)
|
||||
elif util.callable(predicate):
|
||||
return LambdaPredicate(predicate)
|
||||
return LambdaPredicate(predicate, description)
|
||||
else:
|
||||
assert False, "unknown predicate type: %s" % predicate
|
||||
|
||||
def _format_description(self, config, negate=False):
|
||||
bool_ = self(config)
|
||||
if negate:
|
||||
bool_ = not negate
|
||||
return self.description % {
|
||||
"driver": config.db.url.get_driver_name(),
|
||||
"database": config.db.url.get_backend_name(),
|
||||
"doesnt_support": "doesn't support" if bool_ else "does support",
|
||||
"does_support": "does support" if bool_ else "doesn't support"
|
||||
}
|
||||
|
||||
def _as_string(self, config=None, negate=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BooleanPredicate(Predicate):
|
||||
def __init__(self, value, description=None):
|
||||
|
|
@ -134,14 +226,8 @@ class BooleanPredicate(Predicate):
|
|||
def __call__(self, config):
|
||||
return self.value
|
||||
|
||||
def _as_string(self, negate=False):
|
||||
if negate:
|
||||
return "not " + self.description
|
||||
else:
|
||||
return self.description
|
||||
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config, negate=negate)
|
||||
|
||||
|
||||
class SpecPredicate(Predicate):
|
||||
|
|
@ -185,9 +271,9 @@ class SpecPredicate(Predicate):
|
|||
else:
|
||||
return True
|
||||
|
||||
def _as_string(self, negate=False):
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description is not None:
|
||||
return self.description
|
||||
return self._format_description(config)
|
||||
elif self.op is None:
|
||||
if negate:
|
||||
return "not %s" % self.db
|
||||
|
|
@ -207,13 +293,10 @@ class SpecPredicate(Predicate):
|
|||
self.spec
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
|
||||
|
||||
class LambdaPredicate(Predicate):
|
||||
def __init__(self, lambda_, description=None, args=None, kw=None):
|
||||
spec = inspect.getargspec(lambda_)
|
||||
spec = inspect_getargspec(lambda_)
|
||||
if not spec[0]:
|
||||
self.lambda_ = lambda db: lambda_()
|
||||
else:
|
||||
|
|
@ -230,25 +313,23 @@ class LambdaPredicate(Predicate):
|
|||
def __call__(self, config):
|
||||
return self.lambda_(config)
|
||||
|
||||
def _as_string(self, negate=False):
|
||||
if negate:
|
||||
return "not " + self.description
|
||||
else:
|
||||
return self.description
|
||||
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config)
|
||||
|
||||
|
||||
class NotPredicate(Predicate):
|
||||
def __init__(self, predicate):
|
||||
def __init__(self, predicate, description=None):
|
||||
self.predicate = predicate
|
||||
self.description = description
|
||||
|
||||
def __call__(self, config):
|
||||
return not self.predicate(config)
|
||||
|
||||
def __str__(self):
|
||||
return self.predicate._as_string(True)
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description:
|
||||
return self._format_description(config, not negate)
|
||||
else:
|
||||
return self.predicate._as_string(config, not negate)
|
||||
|
||||
|
||||
class OrPredicate(Predicate):
|
||||
|
|
@ -259,40 +340,32 @@ class OrPredicate(Predicate):
|
|||
def __call__(self, config):
|
||||
for pred in self.predicates:
|
||||
if pred(config):
|
||||
self._str = pred
|
||||
return True
|
||||
return False
|
||||
|
||||
_str = None
|
||||
|
||||
def _eval_str(self, negate=False):
|
||||
if self._str is None:
|
||||
if negate:
|
||||
conjunction = " and "
|
||||
else:
|
||||
conjunction = " or "
|
||||
return conjunction.join(p._as_string(negate=negate)
|
||||
for p in self.predicates)
|
||||
else:
|
||||
return self._str._as_string(negate=negate)
|
||||
|
||||
def _negation_str(self):
|
||||
if self.description is not None:
|
||||
return "Not " + (self.description % {"spec": self._str})
|
||||
else:
|
||||
return self._eval_str(negate=True)
|
||||
|
||||
def _as_string(self, negate=False):
|
||||
def _eval_str(self, config, negate=False):
|
||||
if negate:
|
||||
return self._negation_str()
|
||||
conjunction = " and "
|
||||
else:
|
||||
conjunction = " or "
|
||||
return conjunction.join(p._as_string(config, negate=negate)
|
||||
for p in self.predicates)
|
||||
|
||||
def _negation_str(self, config):
|
||||
if self.description is not None:
|
||||
return "Not " + self._format_description(config)
|
||||
else:
|
||||
return self._eval_str(config, negate=True)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if negate:
|
||||
return self._negation_str(config)
|
||||
else:
|
||||
if self.description is not None:
|
||||
return self.description % {"spec": self._str}
|
||||
return self._format_description(config)
|
||||
else:
|
||||
return self._eval_str()
|
||||
return self._eval_str(config)
|
||||
|
||||
def __str__(self):
|
||||
return self._as_string()
|
||||
|
||||
_as_predicate = Predicate.as_predicate
|
||||
|
||||
|
|
@ -325,8 +398,8 @@ def closed():
|
|||
return skip_if(BooleanPredicate(True, "marked as skip"))
|
||||
|
||||
|
||||
def fails():
|
||||
return fails_if(BooleanPredicate(True, "expected to fail"))
|
||||
def fails(reason=None):
|
||||
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
|
||||
|
||||
|
||||
@decorator
|
||||
|
|
@ -341,8 +414,8 @@ def fails_on(db, reason=None):
|
|||
def fails_on_everything_except(*dbs):
|
||||
return succeeds_if(
|
||||
OrPredicate([
|
||||
SpecPredicate(db) for db in dbs
|
||||
])
|
||||
SpecPredicate(db) for db in dbs
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -352,7 +425,7 @@ def skip(db, reason=None):
|
|||
|
||||
def only_on(dbs, reason=None):
|
||||
return only_if(
|
||||
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
|
||||
OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)])
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/fixtures.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -91,20 +91,12 @@ class TablesTest(TestBase):
|
|||
cls.run_create_tables = 'each'
|
||||
assert cls.run_inserts in ('each', None)
|
||||
|
||||
if cls.other is None:
|
||||
cls.other = adict()
|
||||
cls.other = adict()
|
||||
cls.tables = adict()
|
||||
|
||||
if cls.tables is None:
|
||||
cls.tables = adict()
|
||||
|
||||
if cls.bind is None:
|
||||
setattr(cls, 'bind', cls.setup_bind())
|
||||
|
||||
if cls.metadata is None:
|
||||
setattr(cls, 'metadata', sa.MetaData())
|
||||
|
||||
if cls.metadata.bind is None:
|
||||
cls.metadata.bind = cls.bind
|
||||
cls.bind = cls.setup_bind()
|
||||
cls.metadata = sa.MetaData()
|
||||
cls.metadata.bind = cls.bind
|
||||
|
||||
@classmethod
|
||||
def _setup_once_inserts(cls):
|
||||
|
|
@ -142,13 +134,14 @@ class TablesTest(TestBase):
|
|||
def _teardown_each_tables(self):
|
||||
# no need to run deletes if tables are recreated on setup
|
||||
if self.run_define_tables != 'each' and self.run_deletes == 'each':
|
||||
for table in reversed(self.metadata.sorted_tables):
|
||||
try:
|
||||
table.delete().execute().close()
|
||||
except sa.exc.DBAPIError as ex:
|
||||
util.print_(
|
||||
("Error emptying table %s: %r" % (table, ex)),
|
||||
file=sys.stderr)
|
||||
with self.bind.connect() as conn:
|
||||
for table in reversed(self.metadata.sorted_tables):
|
||||
try:
|
||||
conn.execute(table.delete())
|
||||
except sa.exc.DBAPIError as ex:
|
||||
util.print_(
|
||||
("Error emptying table %s: %r" % (table, ex)),
|
||||
file=sys.stderr)
|
||||
|
||||
def setup(self):
|
||||
self._setup_each_tables()
|
||||
|
|
@ -200,9 +193,8 @@ class TablesTest(TestBase):
|
|||
def sql_count_(self, count, fn):
|
||||
self.assert_sql_count(self.bind, fn, count)
|
||||
|
||||
def sql_eq_(self, callable_, statements, with_sequences=None):
|
||||
self.assert_sql(self.bind,
|
||||
callable_, statements, with_sequences)
|
||||
def sql_eq_(self, callable_, statements):
|
||||
self.assert_sql(self.bind, callable_, statements)
|
||||
|
||||
@classmethod
|
||||
def _load_fixtures(cls):
|
||||
|
|
@ -283,12 +275,14 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
|||
|
||||
def setup(self):
|
||||
self._setup_each_tables()
|
||||
self._setup_each_classes()
|
||||
self._setup_each_mappers()
|
||||
self._setup_each_inserts()
|
||||
|
||||
def teardown(self):
|
||||
sa.orm.session.Session.close_all()
|
||||
self._teardown_each_mappers()
|
||||
self._teardown_each_classes()
|
||||
self._teardown_each_tables()
|
||||
|
||||
@classmethod
|
||||
|
|
@ -310,6 +304,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
|||
if self.run_setup_mappers == 'each':
|
||||
self._with_register_classes(self.setup_mappers)
|
||||
|
||||
def _setup_each_classes(self):
|
||||
if self.run_setup_classes == 'each':
|
||||
self._with_register_classes(self.setup_classes)
|
||||
|
||||
@classmethod
|
||||
def _with_register_classes(cls, fn):
|
||||
"""Run a setup method, framing the operation with a Base class
|
||||
|
|
@ -344,6 +342,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
|||
if self.run_setup_mappers != 'once':
|
||||
sa.orm.clear_mappers()
|
||||
|
||||
def _teardown_each_classes(self):
|
||||
if self.run_setup_classes != 'once':
|
||||
self.classes.clear()
|
||||
|
||||
@classmethod
|
||||
def setup_classes(cls):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/mock.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -11,10 +11,10 @@ from __future__ import absolute_import
|
|||
from ..util import py33
|
||||
|
||||
if py33:
|
||||
from unittest.mock import MagicMock, Mock, call, patch
|
||||
from unittest.mock import MagicMock, Mock, call, patch, ANY
|
||||
else:
|
||||
try:
|
||||
from mock import MagicMock, Mock, call, patch
|
||||
from mock import MagicMock, Mock, call, patch, ANY
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"SQLAlchemy's test suite requires the "
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/pickleable.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Bootstrapper for nose/pytest plugins.
|
||||
|
||||
The entire rationale for this system is to get the modules in plugin/
|
||||
imported without importing all of the supporting library, so that we can
|
||||
set up things for testing before coverage starts.
|
||||
|
||||
The rationale for all of plugin/ being *in* the supporting library in the
|
||||
first place is so that the testing and plugin suite is available to other
|
||||
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
|
||||
of the same test environment and standard suites available to
|
||||
SQLAlchemy/Alembic themselves without the need to ship/install a separate
|
||||
package outside of SQLAlchemy.
|
||||
|
||||
NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
|
||||
this should be removable when Alembic targets SQLAlchemy 1.0.0.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
bootstrap_file = locals()['bootstrap_file']
|
||||
to_bootstrap = locals()['to_bootstrap']
|
||||
|
||||
|
||||
def load_file_as_module(name):
|
||||
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
|
||||
if sys.version_info >= (3, 3):
|
||||
from importlib import machinery
|
||||
mod = machinery.SourceFileLoader(name, path).load_module()
|
||||
else:
|
||||
import imp
|
||||
mod = imp.load_source(name, path)
|
||||
return mod
|
||||
|
||||
if to_bootstrap == "pytest":
|
||||
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
|
||||
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
|
||||
elif to_bootstrap == "nose":
|
||||
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
|
||||
sys.modules["sqla_noseplugin"] = load_file_as_module("noseplugin")
|
||||
else:
|
||||
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# plugin/noseplugin.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -12,22 +12,22 @@ way (e.g. as a package-less import).
|
|||
|
||||
"""
|
||||
|
||||
try:
|
||||
# installed by bootstrap.py
|
||||
import sqla_plugin_base as plugin_base
|
||||
except ImportError:
|
||||
# assume we're a package, use traditional import
|
||||
from . import plugin_base
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from nose.plugins import Plugin
|
||||
import nose
|
||||
fixtures = None
|
||||
|
||||
# no package imports yet! this prevents us from tripping coverage
|
||||
# too soon.
|
||||
path = os.path.join(os.path.dirname(__file__), "plugin_base.py")
|
||||
if sys.version_info >= (3, 3):
|
||||
from importlib import machinery
|
||||
plugin_base = machinery.SourceFileLoader(
|
||||
"plugin_base", path).load_module()
|
||||
else:
|
||||
import imp
|
||||
plugin_base = imp.load_source("plugin_base", path)
|
||||
py3k = sys.version_info >= (3, 0)
|
||||
|
||||
|
||||
class NoseSQLAlchemy(Plugin):
|
||||
|
|
@ -57,28 +57,39 @@ class NoseSQLAlchemy(Plugin):
|
|||
|
||||
plugin_base.set_coverage_flag(options.enable_plugin_coverage)
|
||||
|
||||
global fixtures
|
||||
from sqlalchemy.testing import fixtures
|
||||
plugin_base.set_skip_test(nose.SkipTest)
|
||||
|
||||
def begin(self):
|
||||
global fixtures
|
||||
from sqlalchemy.testing import fixtures # noqa
|
||||
|
||||
plugin_base.post_begin()
|
||||
|
||||
def describeTest(self, test):
|
||||
return ""
|
||||
|
||||
def wantFunction(self, fn):
|
||||
if fn.__module__ is None:
|
||||
return False
|
||||
if fn.__module__.startswith('sqlalchemy.testing'):
|
||||
return False
|
||||
return False
|
||||
|
||||
def wantMethod(self, fn):
|
||||
if py3k:
|
||||
if not hasattr(fn.__self__, 'cls'):
|
||||
return False
|
||||
cls = fn.__self__.cls
|
||||
else:
|
||||
cls = fn.im_class
|
||||
return plugin_base.want_method(cls, fn)
|
||||
|
||||
def wantClass(self, cls):
|
||||
return plugin_base.want_class(cls)
|
||||
|
||||
def beforeTest(self, test):
|
||||
plugin_base.before_test(test,
|
||||
test.test.cls.__module__,
|
||||
test.test.cls, test.test.method.__name__)
|
||||
if not hasattr(test.test, 'cls'):
|
||||
return
|
||||
plugin_base.before_test(
|
||||
test,
|
||||
test.test.cls.__module__,
|
||||
test.test.cls, test.test.method.__name__)
|
||||
|
||||
def afterTest(self, test):
|
||||
plugin_base.after_test(test)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# plugin/plugin_base.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -14,12 +14,6 @@ functionality via py.test.
|
|||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
try:
|
||||
# unitttest has a SkipTest also but pytest doesn't
|
||||
# honor it unless nose is imported too...
|
||||
from nose import SkipTest
|
||||
except ImportError:
|
||||
from _pytest.runner import Skipped as SkipTest
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
|
@ -31,7 +25,6 @@ if py3k:
|
|||
else:
|
||||
import ConfigParser as configparser
|
||||
|
||||
|
||||
# late imports
|
||||
fixtures = None
|
||||
engines = None
|
||||
|
|
@ -47,7 +40,8 @@ file_config = None
|
|||
|
||||
|
||||
logging = None
|
||||
db_opts = {}
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
options = None
|
||||
|
||||
|
||||
|
|
@ -69,8 +63,6 @@ def setup_options(make_option):
|
|||
help="Drop all tables in the target database first")
|
||||
make_option("--backend-only", action="store_true", dest="backend_only",
|
||||
help="Run only tests marked with __backend__")
|
||||
make_option("--mockpool", action="store_true", dest="mockpool",
|
||||
help="Use mock pool (asserts only one connection used)")
|
||||
make_option("--low-connections", action="store_true",
|
||||
dest="low_connections",
|
||||
help="Use a low number of distinct connections - "
|
||||
|
|
@ -86,18 +78,56 @@ def setup_options(make_option):
|
|||
dest="cdecimal", default=False,
|
||||
help="Monkeypatch the cdecimal library into Python 'decimal' "
|
||||
"for all tests")
|
||||
make_option("--serverside", action="callback",
|
||||
callback=_server_side_cursors,
|
||||
help="Turn on server side cursors for PG")
|
||||
make_option("--mysql-engine", action="store",
|
||||
dest="mysql_engine", default=None,
|
||||
help="Use the specified MySQL storage engine for all tables, "
|
||||
"default is a db-default/InnoDB combo.")
|
||||
make_option("--tableopts", action="append", dest="tableopts", default=[],
|
||||
help="Add a dialect-specific table option, key=value")
|
||||
make_option("--include-tag", action="callback", callback=_include_tag,
|
||||
type="string",
|
||||
help="Include tests with tag <tag>")
|
||||
make_option("--exclude-tag", action="callback", callback=_exclude_tag,
|
||||
type="string",
|
||||
help="Exclude tests with tag <tag>")
|
||||
make_option("--write-profiles", action="store_true",
|
||||
dest="write_profiles", default=False,
|
||||
help="Write/update profiling data.")
|
||||
help="Write/update failing profiling data.")
|
||||
make_option("--force-write-profiles", action="store_true",
|
||||
dest="force_write_profiles", default=False,
|
||||
help="Unconditionally write/update profiling data.")
|
||||
|
||||
|
||||
def configure_follower(follower_ident):
|
||||
"""Configure required state for a follower.
|
||||
|
||||
This invokes in the parent process and typically includes
|
||||
database creation.
|
||||
|
||||
"""
|
||||
from sqlalchemy.testing import provision
|
||||
provision.FOLLOWER_IDENT = follower_ident
|
||||
|
||||
|
||||
def memoize_important_follower_config(dict_):
|
||||
"""Store important configuration we will need to send to a follower.
|
||||
|
||||
This invokes in the parent process after normal config is set up.
|
||||
|
||||
This is necessary as py.test seems to not be using forking, so we
|
||||
start with nothing in memory, *but* it isn't running our argparse
|
||||
callables, so we have to just copy all of that over.
|
||||
|
||||
"""
|
||||
dict_['memoized_config'] = {
|
||||
'include_tags': include_tags,
|
||||
'exclude_tags': exclude_tags
|
||||
}
|
||||
|
||||
|
||||
def restore_important_follower_config(dict_):
|
||||
"""Restore important configuration needed by a follower.
|
||||
|
||||
This invokes in the follower process.
|
||||
|
||||
"""
|
||||
global include_tags, exclude_tags
|
||||
include_tags.update(dict_['memoized_config']['include_tags'])
|
||||
exclude_tags.update(dict_['memoized_config']['exclude_tags'])
|
||||
|
||||
|
||||
def read_config():
|
||||
|
|
@ -117,6 +147,13 @@ def pre_begin(opt):
|
|||
def set_coverage_flag(value):
|
||||
options.has_coverage = value
|
||||
|
||||
_skip_test_exception = None
|
||||
|
||||
|
||||
def set_skip_test(exc):
|
||||
global _skip_test_exception
|
||||
_skip_test_exception = exc
|
||||
|
||||
|
||||
def post_begin():
|
||||
"""things to set up later, once we know coverage is running."""
|
||||
|
|
@ -129,10 +166,12 @@ def post_begin():
|
|||
global util, fixtures, engines, exclusions, \
|
||||
assertions, warnings, profiling,\
|
||||
config, testing
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.testing import fixtures, engines, exclusions, \
|
||||
assertions, warnings, profiling, config
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy import testing # noqa
|
||||
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
|
||||
from sqlalchemy.testing import assertions, warnings, profiling # noqa
|
||||
from sqlalchemy.testing import config # noqa
|
||||
from sqlalchemy import util # noqa
|
||||
warnings.setup_filters()
|
||||
|
||||
|
||||
def _log(opt_str, value, parser):
|
||||
|
|
@ -154,14 +193,17 @@ def _list_dbs(*args):
|
|||
sys.exit(0)
|
||||
|
||||
|
||||
def _server_side_cursors(opt_str, value, parser):
|
||||
db_opts['server_side_cursors'] = True
|
||||
|
||||
|
||||
def _requirements_opt(opt_str, value, parser):
|
||||
_setup_requirements(value)
|
||||
|
||||
|
||||
def _exclude_tag(opt_str, value, parser):
|
||||
exclude_tags.add(value.replace('-', '_'))
|
||||
|
||||
|
||||
def _include_tag(opt_str, value, parser):
|
||||
include_tags.add(value.replace('-', '_'))
|
||||
|
||||
pre_configure = []
|
||||
post_configure = []
|
||||
|
||||
|
|
@ -189,10 +231,18 @@ def _monkeypatch_cdecimal(options, file_config):
|
|||
sys.modules['decimal'] = cdecimal
|
||||
|
||||
|
||||
@post
|
||||
def _init_skiptest(options, file_config):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
config._skip_test_exception = _skip_test_exception
|
||||
|
||||
|
||||
@post
|
||||
def _engine_uri(options, file_config):
|
||||
from sqlalchemy.testing import engines, config
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.testing import provision
|
||||
|
||||
if options.dburi:
|
||||
db_urls = list(options.dburi)
|
||||
|
|
@ -214,18 +264,11 @@ def _engine_uri(options, file_config):
|
|||
db_urls.append(file_config.get('db', 'default'))
|
||||
|
||||
for db_url in db_urls:
|
||||
eng = engines.testing_engine(db_url, db_opts)
|
||||
eng.connect().close()
|
||||
config.Config.register(eng, db_opts, options, file_config, testing)
|
||||
cfg = provision.setup_config(
|
||||
db_url, options, file_config, provision.FOLLOWER_IDENT)
|
||||
|
||||
config.db_opts = db_opts
|
||||
|
||||
|
||||
@post
|
||||
def _engine_pool(options, file_config):
|
||||
if options.mockpool:
|
||||
from sqlalchemy import pool
|
||||
db_opts['poolclass'] = pool.AssertionPool
|
||||
if not config._current:
|
||||
cfg.set_as_current(cfg, testing)
|
||||
|
||||
|
||||
@post
|
||||
|
|
@ -256,7 +299,8 @@ def _setup_requirements(argument):
|
|||
|
||||
@post
|
||||
def _prep_testing_database(options, file_config):
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy.testing import config, util
|
||||
from sqlalchemy.testing.exclusions import against
|
||||
from sqlalchemy import schema, inspect
|
||||
|
||||
if options.dropfirst:
|
||||
|
|
@ -286,32 +330,18 @@ def _prep_testing_database(options, file_config):
|
|||
schema="test_schema")
|
||||
))
|
||||
|
||||
for tname in reversed(inspector.get_table_names(
|
||||
order_by="foreign_key")):
|
||||
e.execute(schema.DropTable(
|
||||
schema.Table(tname, schema.MetaData())
|
||||
))
|
||||
util.drop_all_tables(e, inspector)
|
||||
|
||||
if config.requirements.schemas.enabled_for_config(cfg):
|
||||
for tname in reversed(inspector.get_table_names(
|
||||
order_by="foreign_key", schema="test_schema")):
|
||||
e.execute(schema.DropTable(
|
||||
schema.Table(tname, schema.MetaData(),
|
||||
schema="test_schema")
|
||||
))
|
||||
util.drop_all_tables(e, inspector, schema=cfg.test_schema)
|
||||
|
||||
|
||||
@post
|
||||
def _set_table_options(options, file_config):
|
||||
from sqlalchemy.testing import schema
|
||||
|
||||
table_options = schema.table_options
|
||||
for spec in options.tableopts:
|
||||
key, value = spec.split('=')
|
||||
table_options[key] = value
|
||||
|
||||
if options.mysql_engine:
|
||||
table_options['mysql_engine'] = options.mysql_engine
|
||||
if against(cfg, "postgresql"):
|
||||
from sqlalchemy.dialects import postgresql
|
||||
for enum in inspector.get_enums("*"):
|
||||
e.execute(postgresql.DropEnumType(
|
||||
postgresql.ENUM(
|
||||
name=enum['name'],
|
||||
schema=enum['schema'])))
|
||||
|
||||
|
||||
@post
|
||||
|
|
@ -347,6 +377,30 @@ def want_class(cls):
|
|||
return True
|
||||
|
||||
|
||||
def want_method(cls, fn):
|
||||
if not fn.__name__.startswith("test_"):
|
||||
return False
|
||||
elif fn.__module__ is None:
|
||||
return False
|
||||
elif include_tags:
|
||||
return (
|
||||
hasattr(cls, '__tags__') and
|
||||
exclusions.tags(cls.__tags__).include_test(
|
||||
include_tags, exclude_tags)
|
||||
) or (
|
||||
hasattr(fn, '_sa_exclusion_extend') and
|
||||
fn._sa_exclusion_extend.include_test(
|
||||
include_tags, exclude_tags)
|
||||
)
|
||||
elif exclude_tags and hasattr(cls, '__tags__'):
|
||||
return exclusions.tags(cls.__tags__).include_test(
|
||||
include_tags, exclude_tags)
|
||||
elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
|
||||
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def generate_sub_tests(cls, module):
|
||||
if getattr(cls, '__backend__', False):
|
||||
for cfg in _possible_configs_for_cls(cls):
|
||||
|
|
@ -356,7 +410,7 @@ def generate_sub_tests(cls, module):
|
|||
(cls, ),
|
||||
{
|
||||
"__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
|
||||
"__backend__": False}
|
||||
}
|
||||
)
|
||||
setattr(module, name, subcls)
|
||||
yield subcls
|
||||
|
|
@ -370,6 +424,8 @@ def start_test_class(cls):
|
|||
|
||||
|
||||
def stop_test_class(cls):
|
||||
#from sqlalchemy import inspect
|
||||
#assert not inspect(testing.db).get_table_names()
|
||||
engines.testing_reaper._stop_test_ctx()
|
||||
if not options.low_connections:
|
||||
assertions.global_cleanup_assertions()
|
||||
|
|
@ -398,33 +454,27 @@ def before_test(test, test_module_name, test_class, test_name):
|
|||
|
||||
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
|
||||
|
||||
warnings.resetwarnings()
|
||||
profiling._current_test = id_
|
||||
|
||||
|
||||
def after_test(test):
|
||||
engines.testing_reaper._after_test_ctx()
|
||||
warnings.resetwarnings()
|
||||
|
||||
|
||||
def _possible_configs_for_cls(cls):
|
||||
def _possible_configs_for_cls(cls, reasons=None):
|
||||
all_configs = set(config.Config.all_configs())
|
||||
|
||||
if cls.__unsupported_on__:
|
||||
spec = exclusions.db_spec(*cls.__unsupported_on__)
|
||||
for config_obj in list(all_configs):
|
||||
if spec(config_obj):
|
||||
all_configs.remove(config_obj)
|
||||
|
||||
if getattr(cls, '__only_on__', None):
|
||||
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
|
||||
for config_obj in list(all_configs):
|
||||
if not spec(config_obj):
|
||||
all_configs.remove(config_obj)
|
||||
return all_configs
|
||||
|
||||
|
||||
def _do_skips(cls):
|
||||
all_configs = _possible_configs_for_cls(cls)
|
||||
reasons = []
|
||||
|
||||
if hasattr(cls, '__requires__'):
|
||||
requirements = config.requirements
|
||||
|
|
@ -432,10 +482,11 @@ def _do_skips(cls):
|
|||
for requirement in cls.__requires__:
|
||||
check = getattr(requirements, requirement)
|
||||
|
||||
if check.predicate(config_obj):
|
||||
skip_reasons = check.matching_config_reasons(config_obj)
|
||||
if skip_reasons:
|
||||
all_configs.remove(config_obj)
|
||||
if check.reason:
|
||||
reasons.append(check.reason)
|
||||
if reasons is not None:
|
||||
reasons.extend(skip_reasons)
|
||||
break
|
||||
|
||||
if hasattr(cls, '__prefer_requires__'):
|
||||
|
|
@ -445,36 +496,45 @@ def _do_skips(cls):
|
|||
for requirement in cls.__prefer_requires__:
|
||||
check = getattr(requirements, requirement)
|
||||
|
||||
if check.predicate(config_obj):
|
||||
if not check.enabled_for_config(config_obj):
|
||||
non_preferred.add(config_obj)
|
||||
if all_configs.difference(non_preferred):
|
||||
all_configs.difference_update(non_preferred)
|
||||
|
||||
return all_configs
|
||||
|
||||
|
||||
def _do_skips(cls):
|
||||
reasons = []
|
||||
all_configs = _possible_configs_for_cls(cls, reasons)
|
||||
|
||||
if getattr(cls, '__skip_if__', False):
|
||||
for c in getattr(cls, '__skip_if__'):
|
||||
if c():
|
||||
raise SkipTest("'%s' skipped by %s" % (
|
||||
config.skip_test("'%s' skipped by %s" % (
|
||||
cls.__name__, c.__name__)
|
||||
)
|
||||
|
||||
for db_spec, op, spec in getattr(cls, '__excluded_on__', ()):
|
||||
for config_obj in list(all_configs):
|
||||
if exclusions.skip_if(
|
||||
exclusions.SpecPredicate(db_spec, op, spec)
|
||||
).predicate(config_obj):
|
||||
all_configs.remove(config_obj)
|
||||
if not all_configs:
|
||||
raise SkipTest(
|
||||
"'%s' unsupported on DB implementation %s%s" % (
|
||||
if getattr(cls, '__backend__', False):
|
||||
msg = "'%s' unsupported for implementation '%s'" % (
|
||||
cls.__name__, cls.__only_on__)
|
||||
else:
|
||||
msg = "'%s' unsupported on any DB implementation %s%s" % (
|
||||
cls.__name__,
|
||||
", ".join("'%s' = %s"
|
||||
% (config_obj.db.name,
|
||||
config_obj.db.dialect.server_version_info)
|
||||
for config_obj in config.Config.all_configs()
|
||||
),
|
||||
", ".join(
|
||||
"'%s(%s)+%s'" % (
|
||||
config_obj.db.name,
|
||||
".".join(
|
||||
str(dig) for dig in
|
||||
config_obj.db.dialect.server_version_info),
|
||||
config_obj.db.driver
|
||||
)
|
||||
for config_obj in config.Config.all_configs()
|
||||
),
|
||||
", ".join(reasons)
|
||||
)
|
||||
)
|
||||
config.skip_test(msg)
|
||||
elif hasattr(cls, '__prefer_backends__'):
|
||||
non_preferred = set()
|
||||
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
|
||||
|
|
|
|||
|
|
@ -1,8 +1,21 @@
|
|||
try:
|
||||
# installed by bootstrap.py
|
||||
import sqla_plugin_base as plugin_base
|
||||
except ImportError:
|
||||
# assume we're a package, use traditional import
|
||||
from . import plugin_base
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
import inspect
|
||||
from . import plugin_base
|
||||
import collections
|
||||
import itertools
|
||||
|
||||
try:
|
||||
import xdist # noqa
|
||||
has_xdist = True
|
||||
except ImportError:
|
||||
has_xdist = False
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
|
@ -24,13 +37,40 @@ def pytest_addoption(parser):
|
|||
|
||||
|
||||
def pytest_configure(config):
|
||||
if hasattr(config, "slaveinput"):
|
||||
plugin_base.restore_important_follower_config(config.slaveinput)
|
||||
plugin_base.configure_follower(
|
||||
config.slaveinput["follower_ident"]
|
||||
)
|
||||
|
||||
plugin_base.pre_begin(config.option)
|
||||
|
||||
plugin_base.set_coverage_flag(bool(getattr(config.option,
|
||||
"cov_source", False)))
|
||||
|
||||
plugin_base.set_skip_test(pytest.skip.Exception)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
plugin_base.post_begin()
|
||||
|
||||
if has_xdist:
|
||||
import uuid
|
||||
|
||||
def pytest_configure_node(node):
|
||||
# the master for each node fills slaveinput dictionary
|
||||
# which pytest-xdist will transfer to the subprocess
|
||||
|
||||
plugin_base.memoize_important_follower_config(node.slaveinput)
|
||||
|
||||
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
|
||||
from sqlalchemy.testing import provision
|
||||
provision.create_follower_db(node.slaveinput["follower_ident"])
|
||||
|
||||
def pytest_testnodedown(node, error):
|
||||
from sqlalchemy.testing import provision
|
||||
provision.drop_follower_db(node.slaveinput["follower_ident"])
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
|
||||
# look for all those classes that specify __backend__ and
|
||||
|
|
@ -44,6 +84,10 @@ def pytest_collection_modifyitems(session, config, items):
|
|||
# new classes to a module on the fly.
|
||||
|
||||
rebuilt_items = collections.defaultdict(list)
|
||||
items[:] = [
|
||||
item for item in
|
||||
items if isinstance(item.parent, pytest.Instance)
|
||||
and not item.parent.parent.name.startswith("_")]
|
||||
test_classes = set(item.parent for item in items)
|
||||
for test_class in test_classes:
|
||||
for sub_cls in plugin_base.generate_sub_tests(
|
||||
|
|
@ -74,12 +118,11 @@ def pytest_collection_modifyitems(session, config, items):
|
|||
|
||||
|
||||
def pytest_pycollect_makeitem(collector, name, obj):
|
||||
|
||||
if inspect.isclass(obj) and plugin_base.want_class(obj):
|
||||
return pytest.Class(name, parent=collector)
|
||||
elif inspect.isfunction(obj) and \
|
||||
name.startswith("test_") and \
|
||||
isinstance(collector, pytest.Instance):
|
||||
isinstance(collector, pytest.Instance) and \
|
||||
plugin_base.want_method(collector.cls, obj):
|
||||
return pytest.Function(name, parent=collector)
|
||||
else:
|
||||
return []
|
||||
|
|
@ -97,16 +140,18 @@ def pytest_runtest_setup(item):
|
|||
return
|
||||
|
||||
# ... so we're doing a little dance here to figure it out...
|
||||
if item.parent.parent is not _current_class:
|
||||
|
||||
if _current_class is None:
|
||||
class_setup(item.parent.parent)
|
||||
_current_class = item.parent.parent
|
||||
|
||||
# this is needed for the class-level, to ensure that the
|
||||
# teardown runs after the class is completed with its own
|
||||
# class-level teardown...
|
||||
item.parent.parent.addfinalizer(
|
||||
lambda: class_teardown(item.parent.parent))
|
||||
def finalize():
|
||||
global _current_class
|
||||
class_teardown(item.parent.parent)
|
||||
_current_class = None
|
||||
item.parent.parent.addfinalizer(finalize)
|
||||
|
||||
test_setup(item)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/profiling.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -14,13 +14,11 @@ in a more fine-grained way than nose's profiling plugin.
|
|||
|
||||
import os
|
||||
import sys
|
||||
from .util import gc_collect, decorator
|
||||
from .util import gc_collect
|
||||
from . import config
|
||||
from .plugin.plugin_base import SkipTest
|
||||
import pstats
|
||||
import time
|
||||
import collections
|
||||
from .. import util
|
||||
import contextlib
|
||||
|
||||
try:
|
||||
import cProfile
|
||||
|
|
@ -30,64 +28,8 @@ from ..util import jython, pypy, win32, update_wrapper
|
|||
|
||||
_current_test = None
|
||||
|
||||
|
||||
def profiled(target=None, **target_opts):
|
||||
"""Function profiling.
|
||||
|
||||
@profiled()
|
||||
or
|
||||
@profiled(report=True, sort=('calls',), limit=20)
|
||||
|
||||
Outputs profiling info for a decorated function.
|
||||
|
||||
"""
|
||||
|
||||
profile_config = {'targets': set(),
|
||||
'report': True,
|
||||
'print_callers': False,
|
||||
'print_callees': False,
|
||||
'graphic': False,
|
||||
'sort': ('time', 'calls'),
|
||||
'limit': None}
|
||||
if target is None:
|
||||
target = 'anonymous_target'
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
elapsed, load_stats, result = _profile(
|
||||
fn, *args, **kw)
|
||||
|
||||
graphic = target_opts.get('graphic', profile_config['graphic'])
|
||||
if graphic:
|
||||
os.system("runsnake %s" % filename)
|
||||
else:
|
||||
report = target_opts.get('report', profile_config['report'])
|
||||
if report:
|
||||
sort_ = target_opts.get('sort', profile_config['sort'])
|
||||
limit = target_opts.get('limit', profile_config['limit'])
|
||||
print(("Profile report for target '%s'" % (
|
||||
target, )
|
||||
))
|
||||
|
||||
stats = load_stats()
|
||||
stats.sort_stats(*sort_)
|
||||
if limit:
|
||||
stats.print_stats(limit)
|
||||
else:
|
||||
stats.print_stats()
|
||||
|
||||
print_callers = target_opts.get(
|
||||
'print_callers', profile_config['print_callers'])
|
||||
if print_callers:
|
||||
stats.print_callers()
|
||||
|
||||
print_callees = target_opts.get(
|
||||
'print_callees', profile_config['print_callees'])
|
||||
if print_callees:
|
||||
stats.print_callees()
|
||||
|
||||
return result
|
||||
return decorate
|
||||
# ProfileStatsFile instance, set up in plugin_base
|
||||
_profile_stats = None
|
||||
|
||||
|
||||
class ProfileStatsFile(object):
|
||||
|
|
@ -99,7 +41,11 @@ class ProfileStatsFile(object):
|
|||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
self.write = (
|
||||
self.force_write = (
|
||||
config.options is not None and
|
||||
config.options.force_write_profiles
|
||||
)
|
||||
self.write = self.force_write or (
|
||||
config.options is not None and
|
||||
config.options.write_profiles
|
||||
)
|
||||
|
|
@ -129,6 +75,11 @@ class ProfileStatsFile(object):
|
|||
platform_tokens.append("pypy")
|
||||
if win32:
|
||||
platform_tokens.append("win")
|
||||
platform_tokens.append(
|
||||
"nativeunicode"
|
||||
if config.db.dialect.convert_unicode
|
||||
else "dbapiunicode"
|
||||
)
|
||||
_has_cext = config.requirements._has_cextensions()
|
||||
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
|
||||
return "_".join(platform_tokens)
|
||||
|
|
@ -172,25 +123,32 @@ class ProfileStatsFile(object):
|
|||
per_fn = self.data[test_key]
|
||||
per_platform = per_fn[self.platform_key]
|
||||
counts = per_platform['counts']
|
||||
counts[-1] = callcount
|
||||
current_count = per_platform['current_count']
|
||||
if current_count < len(counts):
|
||||
counts[current_count - 1] = callcount
|
||||
else:
|
||||
counts[-1] = callcount
|
||||
if self.write:
|
||||
self._write()
|
||||
|
||||
def _header(self):
|
||||
return \
|
||||
"# %s\n"\
|
||||
"# This file is written out on a per-environment basis.\n"\
|
||||
"# For each test in aaa_profiling, the corresponding function and \n"\
|
||||
"# environment is located within this file. If it doesn't exist,\n"\
|
||||
"# the test is skipped.\n"\
|
||||
"# If a callcount does exist, it is compared to what we received. \n"\
|
||||
"# assertions are raised if the counts do not match.\n"\
|
||||
"# \n"\
|
||||
"# To add a new callcount test, apply the function_call_count \n"\
|
||||
"# decorator and re-run the tests using the --write-profiles \n"\
|
||||
"# option - this file will be rewritten including the new count.\n"\
|
||||
"# \n"\
|
||||
"" % (self.fname)
|
||||
return (
|
||||
"# %s\n"
|
||||
"# This file is written out on a per-environment basis.\n"
|
||||
"# For each test in aaa_profiling, the corresponding "
|
||||
"function and \n"
|
||||
"# environment is located within this file. "
|
||||
"If it doesn't exist,\n"
|
||||
"# the test is skipped.\n"
|
||||
"# If a callcount does exist, it is compared "
|
||||
"to what we received. \n"
|
||||
"# assertions are raised if the counts do not match.\n"
|
||||
"# \n"
|
||||
"# To add a new callcount test, apply the function_call_count \n"
|
||||
"# decorator and re-run the tests using the --write-profiles \n"
|
||||
"# option - this file will be rewritten including the new count.\n"
|
||||
"# \n"
|
||||
) % (self.fname)
|
||||
|
||||
def _read(self):
|
||||
try:
|
||||
|
|
@ -239,72 +197,69 @@ def function_call_count(variance=0.05):
|
|||
|
||||
def decorate(fn):
|
||||
def wrap(*args, **kw):
|
||||
|
||||
if cProfile is None:
|
||||
raise SkipTest("cProfile is not installed")
|
||||
|
||||
if not _profile_stats.has_stats() and not _profile_stats.write:
|
||||
# run the function anyway, to support dependent tests
|
||||
# (not a great idea but we have these in test_zoomark)
|
||||
fn(*args, **kw)
|
||||
raise SkipTest("No profiling stats available on this "
|
||||
"platform for this function. Run tests with "
|
||||
"--write-profiles to add statistics to %s for "
|
||||
"this platform." % _profile_stats.short_fname)
|
||||
|
||||
gc_collect()
|
||||
|
||||
timespent, load_stats, fn_result = _profile(
|
||||
fn, *args, **kw
|
||||
)
|
||||
stats = load_stats()
|
||||
callcount = stats.total_calls
|
||||
|
||||
expected = _profile_stats.result(callcount)
|
||||
if expected is None:
|
||||
expected_count = None
|
||||
else:
|
||||
line_no, expected_count = expected
|
||||
|
||||
print(("Pstats calls: %d Expected %s" % (
|
||||
callcount,
|
||||
expected_count
|
||||
)
|
||||
))
|
||||
stats.print_stats()
|
||||
# stats.print_callers()
|
||||
|
||||
if expected_count:
|
||||
deviance = int(callcount * variance)
|
||||
failed = abs(callcount - expected_count) > deviance
|
||||
|
||||
if failed:
|
||||
if _profile_stats.write:
|
||||
_profile_stats.replace(callcount)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Adjusted function call count %s not within %s%% "
|
||||
"of expected %s. Rerun with --write-profiles to "
|
||||
"regenerate this callcount."
|
||||
% (
|
||||
callcount, (variance * 100),
|
||||
expected_count))
|
||||
return fn_result
|
||||
with count_functions(variance=variance):
|
||||
return fn(*args, **kw)
|
||||
return update_wrapper(wrap, fn)
|
||||
return decorate
|
||||
|
||||
|
||||
def _profile(fn, *args, **kw):
|
||||
filename = "%s.prof" % fn.__name__
|
||||
@contextlib.contextmanager
|
||||
def count_functions(variance=0.05):
|
||||
if cProfile is None:
|
||||
raise SkipTest("cProfile is not installed")
|
||||
|
||||
def load_stats():
|
||||
st = pstats.Stats(filename)
|
||||
os.unlink(filename)
|
||||
return st
|
||||
if not _profile_stats.has_stats() and not _profile_stats.write:
|
||||
config.skip_test(
|
||||
"No profiling stats available on this "
|
||||
"platform for this function. Run tests with "
|
||||
"--write-profiles to add statistics to %s for "
|
||||
"this platform." % _profile_stats.short_fname)
|
||||
|
||||
gc_collect()
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
#began = time.time()
|
||||
yield
|
||||
#ended = time.time()
|
||||
pr.disable()
|
||||
|
||||
#s = compat.StringIO()
|
||||
stats = pstats.Stats(pr, stream=sys.stdout)
|
||||
|
||||
#timespent = ended - began
|
||||
callcount = stats.total_calls
|
||||
|
||||
expected = _profile_stats.result(callcount)
|
||||
|
||||
if expected is None:
|
||||
expected_count = None
|
||||
else:
|
||||
line_no, expected_count = expected
|
||||
|
||||
print(("Pstats calls: %d Expected %s" % (
|
||||
callcount,
|
||||
expected_count
|
||||
)
|
||||
))
|
||||
stats.sort_stats("cumulative")
|
||||
stats.print_stats()
|
||||
|
||||
if expected_count:
|
||||
deviance = int(callcount * variance)
|
||||
failed = abs(callcount - expected_count) > deviance
|
||||
|
||||
if failed or _profile_stats.force_write:
|
||||
if _profile_stats.write:
|
||||
_profile_stats.replace(callcount)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Adjusted function call count %s not within %s%% "
|
||||
"of expected %s, platform %s. Rerun with "
|
||||
"--write-profiles to "
|
||||
"regenerate this callcount."
|
||||
% (
|
||||
callcount, (variance * 100),
|
||||
expected_count, _profile_stats.platform_key))
|
||||
|
||||
began = time.time()
|
||||
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
|
||||
filename=filename)
|
||||
ended = time.time()
|
||||
|
||||
return ended - began, load_stats, locals()['result']
|
||||
|
|
|
|||
317
lib/python3.4/site-packages/sqlalchemy/testing/provision.py
Normal file
317
lib/python3.4/site-packages/sqlalchemy/testing/provision.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
from sqlalchemy.engine import url as sa_url
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.util import compat
|
||||
from . import config, engines
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FOLLOWER_IDENT = None
|
||||
|
||||
|
||||
class register(object):
|
||||
def __init__(self):
|
||||
self.fns = {}
|
||||
|
||||
@classmethod
|
||||
def init(cls, fn):
|
||||
return register().for_db("*")(fn)
|
||||
|
||||
def for_db(self, dbname):
|
||||
def decorate(fn):
|
||||
self.fns[dbname] = fn
|
||||
return self
|
||||
return decorate
|
||||
|
||||
def __call__(self, cfg, *arg):
|
||||
if isinstance(cfg, compat.string_types):
|
||||
url = sa_url.make_url(cfg)
|
||||
elif isinstance(cfg, sa_url.URL):
|
||||
url = cfg
|
||||
else:
|
||||
url = cfg.db.url
|
||||
backend = url.get_backend_name()
|
||||
if backend in self.fns:
|
||||
return self.fns[backend](cfg, *arg)
|
||||
else:
|
||||
return self.fns['*'](cfg, *arg)
|
||||
|
||||
|
||||
def create_follower_db(follower_ident):
|
||||
|
||||
for cfg in _configs_for_db_operation():
|
||||
_create_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def configure_follower(follower_ident):
|
||||
for cfg in config.Config.all_configs():
|
||||
_configure_follower(cfg, follower_ident)
|
||||
|
||||
|
||||
def setup_config(db_url, options, file_config, follower_ident):
|
||||
if follower_ident:
|
||||
db_url = _follower_url_from_main(db_url, follower_ident)
|
||||
db_opts = {}
|
||||
_update_db_opts(db_url, db_opts)
|
||||
eng = engines.testing_engine(db_url, db_opts)
|
||||
_post_configure_engine(db_url, eng, follower_ident)
|
||||
eng.connect().close()
|
||||
cfg = config.Config.register(eng, db_opts, options, file_config)
|
||||
if follower_ident:
|
||||
_configure_follower(cfg, follower_ident)
|
||||
return cfg
|
||||
|
||||
|
||||
def drop_follower_db(follower_ident):
|
||||
for cfg in _configs_for_db_operation():
|
||||
_drop_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def _configs_for_db_operation():
|
||||
hosts = set()
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
cfg.db.dispose()
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
url = cfg.db.url
|
||||
backend = url.get_backend_name()
|
||||
host_conf = (
|
||||
backend,
|
||||
url.username, url.host, url.database)
|
||||
|
||||
if host_conf not in hosts:
|
||||
yield cfg
|
||||
hosts.add(host_conf)
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
cfg.db.dispose()
|
||||
|
||||
|
||||
@register.init
|
||||
def _create_db(cfg, eng, ident):
|
||||
raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
|
||||
|
||||
|
||||
@register.init
|
||||
def _drop_db(cfg, eng, ident):
|
||||
raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
|
||||
|
||||
|
||||
@register.init
|
||||
def _update_db_opts(db_url, db_opts):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def _configure_follower(cfg, ident):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def _post_configure_engine(url, engine, follower_ident):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def _follower_url_from_main(url, ident):
|
||||
url = sa_url.make_url(url)
|
||||
url.database = ident
|
||||
return url
|
||||
|
||||
|
||||
@_update_db_opts.for_db("mssql")
|
||||
def _mssql_update_db_opts(db_url, db_opts):
|
||||
db_opts['legacy_schema_aliasing'] = False
|
||||
|
||||
|
||||
@_follower_url_from_main.for_db("sqlite")
|
||||
def _sqlite_follower_url_from_main(url, ident):
|
||||
url = sa_url.make_url(url)
|
||||
if not url.database or url.database == ':memory:':
|
||||
return url
|
||||
else:
|
||||
return sa_url.make_url("sqlite:///%s.db" % ident)
|
||||
|
||||
|
||||
@_post_configure_engine.for_db("sqlite")
|
||||
def _sqlite_post_configure_engine(url, engine, follower_ident):
|
||||
from sqlalchemy import event
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def connect(dbapi_connection, connection_record):
|
||||
# use file DBs in all cases, memory acts kind of strangely
|
||||
# as an attached
|
||||
if not follower_ident:
|
||||
dbapi_connection.execute(
|
||||
'ATTACH DATABASE "test_schema.db" AS test_schema')
|
||||
else:
|
||||
dbapi_connection.execute(
|
||||
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
|
||||
% follower_ident)
|
||||
|
||||
|
||||
@_create_db.for_db("postgresql")
|
||||
def _pg_create_db(cfg, eng, ident):
|
||||
with eng.connect().execution_options(
|
||||
isolation_level="AUTOCOMMIT") as conn:
|
||||
try:
|
||||
_pg_drop_db(cfg, conn, ident)
|
||||
except Exception:
|
||||
pass
|
||||
currentdb = conn.scalar("select current_database()")
|
||||
for attempt in range(3):
|
||||
try:
|
||||
conn.execute(
|
||||
"CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
|
||||
except exc.OperationalError as err:
|
||||
if attempt != 2 and "accessed by other users" in str(err):
|
||||
time.sleep(.2)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@_create_db.for_db("mysql")
|
||||
def _mysql_create_db(cfg, eng, ident):
|
||||
with eng.connect() as conn:
|
||||
try:
|
||||
_mysql_drop_db(cfg, conn, ident)
|
||||
except Exception:
|
||||
pass
|
||||
conn.execute("CREATE DATABASE %s" % ident)
|
||||
conn.execute("CREATE DATABASE %s_test_schema" % ident)
|
||||
conn.execute("CREATE DATABASE %s_test_schema_2" % ident)
|
||||
|
||||
|
||||
@_configure_follower.for_db("mysql")
|
||||
def _mysql_configure_follower(config, ident):
|
||||
config.test_schema = "%s_test_schema" % ident
|
||||
config.test_schema_2 = "%s_test_schema_2" % ident
|
||||
|
||||
|
||||
@_create_db.for_db("sqlite")
|
||||
def _sqlite_create_db(cfg, eng, ident):
|
||||
pass
|
||||
|
||||
|
||||
@_drop_db.for_db("postgresql")
|
||||
def _pg_drop_db(cfg, eng, ident):
|
||||
with eng.connect().execution_options(
|
||||
isolation_level="AUTOCOMMIT") as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"select pg_terminate_backend(pid) from pg_stat_activity "
|
||||
"where usename=current_user and pid != pg_backend_pid() "
|
||||
"and datname=:dname"
|
||||
), dname=ident)
|
||||
conn.execute("DROP DATABASE %s" % ident)
|
||||
|
||||
|
||||
@_drop_db.for_db("sqlite")
|
||||
def _sqlite_drop_db(cfg, eng, ident):
|
||||
if ident:
|
||||
os.remove("%s_test_schema.db" % ident)
|
||||
else:
|
||||
os.remove("%s.db" % ident)
|
||||
|
||||
|
||||
@_drop_db.for_db("mysql")
|
||||
def _mysql_drop_db(cfg, eng, ident):
|
||||
with eng.connect() as conn:
|
||||
try:
|
||||
conn.execute("DROP DATABASE %s_test_schema" % ident)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.execute("DROP DATABASE %s_test_schema_2" % ident)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.execute("DROP DATABASE %s" % ident)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@_create_db.for_db("oracle")
|
||||
def _oracle_create_db(cfg, eng, ident):
|
||||
# NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
|
||||
# similar, so that the default tablespace is not "system"; reflection will
|
||||
# fail otherwise
|
||||
with eng.connect() as conn:
|
||||
conn.execute("create user %s identified by xe" % ident)
|
||||
conn.execute("create user %s_ts1 identified by xe" % ident)
|
||||
conn.execute("create user %s_ts2 identified by xe" % ident)
|
||||
conn.execute("grant dba to %s" % (ident, ))
|
||||
conn.execute("grant unlimited tablespace to %s" % ident)
|
||||
conn.execute("grant unlimited tablespace to %s_ts1" % ident)
|
||||
conn.execute("grant unlimited tablespace to %s_ts2" % ident)
|
||||
|
||||
@_configure_follower.for_db("oracle")
|
||||
def _oracle_configure_follower(config, ident):
|
||||
config.test_schema = "%s_ts1" % ident
|
||||
config.test_schema_2 = "%s_ts2" % ident
|
||||
|
||||
|
||||
def _ora_drop_ignore(conn, dbname):
|
||||
try:
|
||||
conn.execute("drop user %s cascade" % dbname)
|
||||
log.info("Reaped db: %s" % dbname)
|
||||
return True
|
||||
except exc.DatabaseError as err:
|
||||
log.warn("couldn't drop db: %s" % err)
|
||||
return False
|
||||
|
||||
|
||||
@_drop_db.for_db("oracle")
|
||||
def _oracle_drop_db(cfg, eng, ident):
|
||||
with eng.connect() as conn:
|
||||
# cx_Oracle seems to occasionally leak open connections when a large
|
||||
# suite it run, even if we confirm we have zero references to
|
||||
# connection objects.
|
||||
# while there is a "kill session" command in Oracle,
|
||||
# it unfortunately does not release the connection sufficiently.
|
||||
_ora_drop_ignore(conn, ident)
|
||||
_ora_drop_ignore(conn, "%s_ts1" % ident)
|
||||
_ora_drop_ignore(conn, "%s_ts2" % ident)
|
||||
|
||||
|
||||
def reap_oracle_dbs(eng):
|
||||
log.info("Reaping Oracle dbs...")
|
||||
with eng.connect() as conn:
|
||||
to_reap = conn.execute(
|
||||
"select u.username from all_users u where username "
|
||||
"like 'TEST_%' and not exists (select username "
|
||||
"from v$session where username=u.username)")
|
||||
all_names = set([username.lower() for (username, ) in to_reap])
|
||||
to_drop = set()
|
||||
for name in all_names:
|
||||
if name.endswith("_ts1") or name.endswith("_ts2"):
|
||||
continue
|
||||
else:
|
||||
to_drop.add(name)
|
||||
if "%s_ts1" % name in all_names:
|
||||
to_drop.add("%s_ts1" % name)
|
||||
if "%s_ts2" % name in all_names:
|
||||
to_drop.add("%s_ts2" % name)
|
||||
|
||||
dropped = total = 0
|
||||
for total, username in enumerate(to_drop, 1):
|
||||
if _ora_drop_ignore(conn, username):
|
||||
dropped += 1
|
||||
log.info(
|
||||
"Dropped %d out of %d stale databases detected", dropped, total)
|
||||
|
||||
|
||||
@_follower_url_from_main.for_db("oracle")
|
||||
def _oracle_follower_url_from_main(url, ident):
|
||||
url = sa_url.make_url(url)
|
||||
url.username = ident
|
||||
url.password = 'xe'
|
||||
return url
|
||||
|
||||
|
||||
172
lib/python3.4/site-packages/sqlalchemy/testing/replay_fixture.py
Normal file
172
lib/python3.4/site-packages/sqlalchemy/testing/replay_fixture.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
from . import fixtures
|
||||
from . import profiling
|
||||
from .. import util
|
||||
import types
|
||||
from collections import deque
|
||||
import contextlib
|
||||
from . import config
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class ReplayFixtureTest(fixtures.TestBase):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _dummy_ctx(self, *arg, **kw):
|
||||
yield
|
||||
|
||||
def test_invocation(self):
|
||||
|
||||
dbapi_session = ReplayableSession()
|
||||
creator = config.db.pool._creator
|
||||
recorder = lambda: dbapi_session.recorder(creator())
|
||||
engine = create_engine(
|
||||
config.db.url, creator=recorder,
|
||||
use_native_hstore=False)
|
||||
self.metadata = MetaData(engine)
|
||||
self.engine = engine
|
||||
self.session = Session(engine)
|
||||
|
||||
self.setup_engine()
|
||||
try:
|
||||
self._run_steps(ctx=self._dummy_ctx)
|
||||
finally:
|
||||
self.teardown_engine()
|
||||
engine.dispose()
|
||||
|
||||
player = lambda: dbapi_session.player()
|
||||
engine = create_engine(
|
||||
config.db.url, creator=player,
|
||||
use_native_hstore=False)
|
||||
|
||||
self.metadata = MetaData(engine)
|
||||
self.engine = engine
|
||||
self.session = Session(engine)
|
||||
|
||||
self.setup_engine()
|
||||
try:
|
||||
self._run_steps(ctx=profiling.count_functions)
|
||||
finally:
|
||||
self.session.close()
|
||||
engine.dispose()
|
||||
|
||||
def setup_engine(self):
|
||||
pass
|
||||
|
||||
def teardown_engine(self):
|
||||
pass
|
||||
|
||||
def _run_steps(self, ctx):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ReplayableSession(object):
|
||||
"""A simple record/playback tool.
|
||||
|
||||
This is *not* a mock testing class. It only records a session for later
|
||||
playback and makes no assertions on call consistency whatsoever. It's
|
||||
unlikely to be suitable for anything other than DB-API recording.
|
||||
|
||||
"""
|
||||
|
||||
Callable = object()
|
||||
NoAttribute = object()
|
||||
|
||||
if util.py2k:
|
||||
Natives = set([getattr(types, t)
|
||||
for t in dir(types) if not t.startswith('_')]).\
|
||||
difference([getattr(types, t)
|
||||
for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
'MethodType', 'BuiltinMethodType',
|
||||
'LambdaType', 'UnboundMethodType',)])
|
||||
else:
|
||||
Natives = set([getattr(types, t)
|
||||
for t in dir(types) if not t.startswith('_')]).\
|
||||
union([type(t) if not isinstance(t, type)
|
||||
else t for t in __builtins__.values()]).\
|
||||
difference([getattr(types, t)
|
||||
for t in ('FunctionType', 'BuiltinFunctionType',
|
||||
'MethodType', 'BuiltinMethodType',
|
||||
'LambdaType', )])
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = deque()
|
||||
|
||||
def recorder(self, base):
|
||||
return self.Recorder(self.buffer, base)
|
||||
|
||||
def player(self):
|
||||
return self.Player(self.buffer)
|
||||
|
||||
class Recorder(object):
|
||||
def __init__(self, buffer, subject):
|
||||
self._buffer = buffer
|
||||
self._subject = subject
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
|
||||
result = subject(*args, **kw)
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return self._subject
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
subject, buffer = [object.__getattribute__(self, x)
|
||||
for x in ('_subject', '_buffer')]
|
||||
try:
|
||||
result = type(subject).__getattribute__(subject, key)
|
||||
except AttributeError:
|
||||
buffer.append(ReplayableSession.NoAttribute)
|
||||
raise
|
||||
else:
|
||||
if type(result) not in ReplayableSession.Natives:
|
||||
buffer.append(ReplayableSession.Callable)
|
||||
return type(self)(buffer, result)
|
||||
else:
|
||||
buffer.append(result)
|
||||
return result
|
||||
|
||||
class Player(object):
|
||||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
|
||||
def __call__(self, *args, **kw):
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _sqla_unwrap(self):
|
||||
return None
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
buffer = object.__getattribute__(self, '_buffer')
|
||||
result = buffer.popleft()
|
||||
if result is ReplayableSession.Callable:
|
||||
return self
|
||||
elif result is ReplayableSession.NoAttribute:
|
||||
raise AttributeError(key)
|
||||
else:
|
||||
return result
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/requirements.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -16,6 +16,7 @@ to provide specific inclusion/exclusions.
|
|||
"""
|
||||
|
||||
from . import exclusions
|
||||
from .. import util
|
||||
|
||||
|
||||
class Requirements(object):
|
||||
|
|
@ -101,6 +102,14 @@ class SuiteRequirements(Requirements):
|
|||
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def bound_limit_offset(self):
|
||||
"""target database can render LIMIT and/or OFFSET using a bound
|
||||
parameter
|
||||
"""
|
||||
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def boolean_col_expressions(self):
|
||||
"""Target database must support boolean expressions as columns"""
|
||||
|
|
@ -179,7 +188,7 @@ class SuiteRequirements(Requirements):
|
|||
|
||||
return exclusions.only_if(
|
||||
lambda config: config.db.dialect.implicit_returning,
|
||||
"'returning' not supported by database"
|
||||
"%(database)s %(does_support)s 'returning'"
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -304,6 +313,25 @@ class SuiteRequirements(Requirements):
|
|||
def foreign_key_constraint_reflection(self):
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def temp_table_reflection(self):
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def temp_table_names(self):
|
||||
"""target dialect supports listing of temporary table names"""
|
||||
return exclusions.closed()
|
||||
|
||||
@property
|
||||
def temporary_tables(self):
|
||||
"""target database supports temporary tables"""
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def temporary_views(self):
|
||||
"""target database supports temporary views"""
|
||||
return exclusions.closed()
|
||||
|
||||
@property
|
||||
def index_reflection(self):
|
||||
return exclusions.open()
|
||||
|
|
@ -313,6 +341,14 @@ class SuiteRequirements(Requirements):
|
|||
"""target dialect supports reflection of unique constraints"""
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def duplicate_key_raises_integrity_error(self):
|
||||
"""target dialect raises IntegrityError when reporting an INSERT
|
||||
with a primary key violation. (hint: it should)
|
||||
|
||||
"""
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def unbounded_varchar(self):
|
||||
"""Target database must support VARCHAR with no length"""
|
||||
|
|
@ -584,6 +620,14 @@ class SuiteRequirements(Requirements):
|
|||
"""
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def graceful_disconnects(self):
|
||||
"""Target driver must raise a DBAPI-level exception, such as
|
||||
InterfaceError, when the underlying connection has been closed
|
||||
and the execute() method is called.
|
||||
"""
|
||||
return exclusions.open()
|
||||
|
||||
@property
|
||||
def skip_mysql_on_windows(self):
|
||||
"""Catchall for a large variety of MySQL on Windows failures"""
|
||||
|
|
@ -601,6 +645,38 @@ class SuiteRequirements(Requirements):
|
|||
return exclusions.skip_if(
|
||||
lambda config: config.options.low_connections)
|
||||
|
||||
@property
|
||||
def timing_intensive(self):
|
||||
return exclusions.requires_tag("timing_intensive")
|
||||
|
||||
@property
|
||||
def memory_intensive(self):
|
||||
return exclusions.requires_tag("memory_intensive")
|
||||
|
||||
@property
|
||||
def threading_with_mock(self):
|
||||
"""Mark tests that use threading and mock at the same time - stability
|
||||
issues have been observed with coverage + python 3.3
|
||||
|
||||
"""
|
||||
return exclusions.skip_if(
|
||||
lambda config: util.py3k and config.options.has_coverage,
|
||||
"Stability issues with coverage + py3k"
|
||||
)
|
||||
|
||||
@property
|
||||
def no_coverage(self):
|
||||
"""Test should be skipped if coverage is enabled.
|
||||
|
||||
This is to block tests that exercise libraries that seem to be
|
||||
sensitive to coverage, such as Postgresql notice logging.
|
||||
|
||||
"""
|
||||
return exclusions.skip_if(
|
||||
lambda config: config.options.has_coverage,
|
||||
"Issues observed when coverage is enabled"
|
||||
)
|
||||
|
||||
def _has_mysql_on_windows(self, config):
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# testing/runner.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -30,7 +30,7 @@ SQLAlchemy itself is possible.
|
|||
|
||||
"""
|
||||
|
||||
from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy
|
||||
from .plugin.noseplugin import NoseSQLAlchemy
|
||||
|
||||
import nose
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/schema.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -67,7 +67,7 @@ def Column(*args, **kw):
|
|||
test_opts = dict([(k, kw.pop(k)) for k in list(kw)
|
||||
if k.startswith('test_')])
|
||||
|
||||
if config.requirements.foreign_key_ddl.predicate(config):
|
||||
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
|
||||
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
|
||||
|
||||
col = schema.Column(*args, **kw)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
|
||||
from sqlalchemy.testing.suite.test_dialect import *
|
||||
from sqlalchemy.testing.suite.test_ddl import *
|
||||
from sqlalchemy.testing.suite.test_insert import *
|
||||
from sqlalchemy.testing.suite.test_sequence import *
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
from .. import fixtures, config
|
||||
from ..config import requirements
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import Integer, String
|
||||
from .. import assert_raises
|
||||
from ..schema import Table, Column
|
||||
|
||||
|
||||
class ExceptionTest(fixtures.TablesTest):
|
||||
"""Test basic exception wrapping.
|
||||
|
||||
DBAPIs vary a lot in exception behavior so to actually anticipate
|
||||
specific exceptions from real round trips, we need to be conservative.
|
||||
|
||||
"""
|
||||
run_deletes = 'each'
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table('manual_pk', metadata,
|
||||
Column('id', Integer, primary_key=True, autoincrement=False),
|
||||
Column('data', String(50))
|
||||
)
|
||||
|
||||
@requirements.duplicate_key_raises_integrity_error
|
||||
def test_integrity_error(self):
|
||||
|
||||
with config.db.begin() as conn:
|
||||
conn.execute(
|
||||
self.tables.manual_pk.insert(),
|
||||
{'id': 1, 'data': 'd1'}
|
||||
)
|
||||
|
||||
assert_raises(
|
||||
exc.IntegrityError,
|
||||
conn.execute,
|
||||
self.tables.manual_pk.insert(),
|
||||
{'id': 1, 'data': 'd1'}
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@ from .. import exclusions
|
|||
from ..assertions import eq_
|
||||
from .. import engines
|
||||
|
||||
from sqlalchemy import Integer, String, select, util
|
||||
from sqlalchemy import Integer, String, select, literal_column, literal
|
||||
|
||||
from ..schema import Table, Column
|
||||
|
||||
|
|
@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest):
|
|||
Column('id', Integer, primary_key=True, autoincrement=False),
|
||||
Column('data', String(50))
|
||||
)
|
||||
Table('includes_defaults', metadata,
|
||||
Column('id', Integer, primary_key=True,
|
||||
test_needs_autoincrement=True),
|
||||
Column('data', String(50)),
|
||||
Column('x', Integer, default=5),
|
||||
Column('y', Integer,
|
||||
default=literal_column("2", type_=Integer) + literal(2)))
|
||||
|
||||
def test_autoclose_on_insert(self):
|
||||
if requirements.returning.enabled:
|
||||
|
|
@ -102,7 +109,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
|
|||
self.tables.autoinc_pk.insert(),
|
||||
data="some data"
|
||||
)
|
||||
assert r.closed
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
assert not r.returns_rows
|
||||
|
||||
|
|
@ -112,7 +120,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
|
|||
self.tables.autoinc_pk.insert(),
|
||||
data="some data"
|
||||
)
|
||||
assert r.closed
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
assert not r.returns_rows
|
||||
|
||||
|
|
@ -121,7 +130,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
|
|||
r = config.db.execute(
|
||||
self.tables.autoinc_pk.insert(),
|
||||
)
|
||||
assert r.closed
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
|
||||
r = config.db.execute(
|
||||
self.tables.autoinc_pk.select().
|
||||
|
|
@ -158,6 +168,34 @@ class InsertBehaviorTest(fixtures.TablesTest):
|
|||
("data3", ), ("data3", )]
|
||||
)
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_with_defaults(self):
|
||||
table = self.tables.includes_defaults
|
||||
config.db.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
]
|
||||
)
|
||||
|
||||
config.db.execute(
|
||||
table.insert(inline=True).
|
||||
from_select(("id", "data",),
|
||||
select([table.c.id + 5, table.c.data]).
|
||||
where(table.c.data.in_(["data2", "data3"]))
|
||||
),
|
||||
)
|
||||
|
||||
eq_(
|
||||
config.db.execute(
|
||||
select([table]).order_by(table.c.data, table.c.id)
|
||||
).fetchall(),
|
||||
[(1, 'data1', 5, 4), (2, 'data2', 5, 4),
|
||||
(7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
|
||||
)
|
||||
|
||||
|
||||
class ReturningTest(fixtures.TablesTest):
|
||||
run_create_tables = 'each'
|
||||
|
|
|
|||
|
|
@ -39,11 +39,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def setup_bind(cls):
|
||||
if config.requirements.independent_connections.enabled:
|
||||
from sqlalchemy import pool
|
||||
return engines.testing_engine(
|
||||
options=dict(poolclass=pool.StaticPool))
|
||||
else:
|
||||
return config.db
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
cls.define_reflected_tables(metadata, None)
|
||||
if testing.requires.schemas.enabled:
|
||||
cls.define_reflected_tables(metadata, "test_schema")
|
||||
cls.define_reflected_tables(metadata, testing.config.test_schema)
|
||||
|
||||
@classmethod
|
||||
def define_reflected_tables(cls, metadata, schema):
|
||||
|
|
@ -95,6 +104,43 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
cls.define_index(metadata, users)
|
||||
if testing.requires.view_column_reflection.enabled:
|
||||
cls.define_views(metadata, schema)
|
||||
if not schema and testing.requires.temp_table_reflection.enabled:
|
||||
cls.define_temp_tables(metadata)
|
||||
|
||||
@classmethod
|
||||
def define_temp_tables(cls, metadata):
|
||||
# cheat a bit, we should fix this with some dialect-level
|
||||
# temp table fixture
|
||||
if testing.against("oracle"):
|
||||
kw = {
|
||||
'prefixes': ["GLOBAL TEMPORARY"],
|
||||
'oracle_on_commit': 'PRESERVE ROWS'
|
||||
}
|
||||
else:
|
||||
kw = {
|
||||
'prefixes': ["TEMPORARY"],
|
||||
}
|
||||
|
||||
user_tmp = Table(
|
||||
"user_tmp", metadata,
|
||||
Column("id", sa.INT, primary_key=True),
|
||||
Column('name', sa.VARCHAR(50)),
|
||||
Column('foo', sa.INT),
|
||||
sa.UniqueConstraint('name', name='user_tmp_uq'),
|
||||
sa.Index("user_tmp_ix", "foo"),
|
||||
**kw
|
||||
)
|
||||
if testing.requires.view_reflection.enabled and \
|
||||
testing.requires.temporary_views.enabled:
|
||||
event.listen(
|
||||
user_tmp, "after_create",
|
||||
DDL("create temporary view user_tmp_v as "
|
||||
"select * from user_tmp")
|
||||
)
|
||||
event.listen(
|
||||
user_tmp, "before_drop",
|
||||
DDL("drop view user_tmp_v")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def define_index(cls, metadata, users):
|
||||
|
|
@ -126,7 +172,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
def test_get_schema_names(self):
|
||||
insp = inspect(testing.db)
|
||||
|
||||
self.assert_('test_schema' in insp.get_schema_names())
|
||||
self.assert_(testing.config.test_schema in insp.get_schema_names())
|
||||
|
||||
@testing.requires.schema_reflection
|
||||
def test_dialect_initialize(self):
|
||||
|
|
@ -147,6 +193,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
users, addresses, dingalings = self.tables.users, \
|
||||
self.tables.email_addresses, self.tables.dingalings
|
||||
insp = inspect(meta.bind)
|
||||
|
||||
if table_type == 'view':
|
||||
table_names = insp.get_view_names(schema)
|
||||
table_names.sort()
|
||||
|
|
@ -162,6 +209,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
answer = ['dingalings', 'email_addresses', 'users']
|
||||
eq_(sorted(table_names), answer)
|
||||
|
||||
@testing.requires.temp_table_names
|
||||
def test_get_temp_table_names(self):
|
||||
insp = inspect(self.bind)
|
||||
temp_table_names = insp.get_temp_table_names()
|
||||
eq_(sorted(temp_table_names), ['user_tmp'])
|
||||
|
||||
@testing.requires.view_reflection
|
||||
@testing.requires.temp_table_names
|
||||
@testing.requires.temporary_views
|
||||
def test_get_temp_view_names(self):
|
||||
insp = inspect(self.bind)
|
||||
temp_table_names = insp.get_temp_view_names()
|
||||
eq_(sorted(temp_table_names), ['user_tmp_v'])
|
||||
|
||||
@testing.requires.table_reflection
|
||||
def test_get_table_names(self):
|
||||
self._test_get_table_names()
|
||||
|
|
@ -174,7 +235,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.table_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_table_names_with_schema(self):
|
||||
self._test_get_table_names('test_schema')
|
||||
self._test_get_table_names(testing.config.test_schema)
|
||||
|
||||
@testing.requires.view_column_reflection
|
||||
def test_get_view_names(self):
|
||||
|
|
@ -183,7 +244,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.view_column_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_view_names_with_schema(self):
|
||||
self._test_get_table_names('test_schema', table_type='view')
|
||||
self._test_get_table_names(
|
||||
testing.config.test_schema, table_type='view')
|
||||
|
||||
@testing.requires.table_reflection
|
||||
@testing.requires.view_column_reflection
|
||||
|
|
@ -291,7 +353,29 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.table_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_columns_with_schema(self):
|
||||
self._test_get_columns(schema='test_schema')
|
||||
self._test_get_columns(schema=testing.config.test_schema)
|
||||
|
||||
@testing.requires.temp_table_reflection
|
||||
def test_get_temp_table_columns(self):
|
||||
meta = MetaData(self.bind)
|
||||
user_tmp = self.tables.user_tmp
|
||||
insp = inspect(meta.bind)
|
||||
cols = insp.get_columns('user_tmp')
|
||||
self.assert_(len(cols) > 0, len(cols))
|
||||
|
||||
for i, col in enumerate(user_tmp.columns):
|
||||
eq_(col.name, cols[i]['name'])
|
||||
|
||||
@testing.requires.temp_table_reflection
|
||||
@testing.requires.view_column_reflection
|
||||
@testing.requires.temporary_views
|
||||
def test_get_temp_view_columns(self):
|
||||
insp = inspect(self.bind)
|
||||
cols = insp.get_columns('user_tmp_v')
|
||||
eq_(
|
||||
[col['name'] for col in cols],
|
||||
['id', 'name', 'foo']
|
||||
)
|
||||
|
||||
@testing.requires.view_column_reflection
|
||||
def test_get_view_columns(self):
|
||||
|
|
@ -300,7 +384,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.view_column_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_view_columns_with_schema(self):
|
||||
self._test_get_columns(schema='test_schema', table_type='view')
|
||||
self._test_get_columns(
|
||||
schema=testing.config.test_schema, table_type='view')
|
||||
|
||||
@testing.provide_metadata
|
||||
def _test_get_pk_constraint(self, schema=None):
|
||||
|
|
@ -327,7 +412,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.primary_key_constraint_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_pk_constraint_with_schema(self):
|
||||
self._test_get_pk_constraint(schema='test_schema')
|
||||
self._test_get_pk_constraint(schema=testing.config.test_schema)
|
||||
|
||||
@testing.requires.table_reflection
|
||||
@testing.provide_metadata
|
||||
|
|
@ -385,7 +470,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.foreign_key_constraint_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_foreign_keys_with_schema(self):
|
||||
self._test_get_foreign_keys(schema='test_schema')
|
||||
self._test_get_foreign_keys(schema=testing.config.test_schema)
|
||||
|
||||
@testing.provide_metadata
|
||||
def _test_get_indexes(self, schema=None):
|
||||
|
|
@ -418,25 +503,57 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.index_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_indexes_with_schema(self):
|
||||
self._test_get_indexes(schema='test_schema')
|
||||
self._test_get_indexes(schema=testing.config.test_schema)
|
||||
|
||||
@testing.requires.unique_constraint_reflection
|
||||
def test_get_unique_constraints(self):
|
||||
self._test_get_unique_constraints()
|
||||
|
||||
@testing.requires.temp_table_reflection
|
||||
@testing.requires.unique_constraint_reflection
|
||||
def test_get_temp_table_unique_constraints(self):
|
||||
insp = inspect(self.bind)
|
||||
reflected = insp.get_unique_constraints('user_tmp')
|
||||
for refl in reflected:
|
||||
# Different dialects handle duplicate index and constraints
|
||||
# differently, so ignore this flag
|
||||
refl.pop('duplicates_index', None)
|
||||
eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
|
||||
|
||||
@testing.requires.temp_table_reflection
|
||||
def test_get_temp_table_indexes(self):
|
||||
insp = inspect(self.bind)
|
||||
indexes = insp.get_indexes('user_tmp')
|
||||
for ind in indexes:
|
||||
ind.pop('dialect_options', None)
|
||||
eq_(
|
||||
# TODO: we need to add better filtering for indexes/uq constraints
|
||||
# that are doubled up
|
||||
[idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
|
||||
[{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
|
||||
)
|
||||
|
||||
@testing.requires.unique_constraint_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_unique_constraints_with_schema(self):
|
||||
self._test_get_unique_constraints(schema='test_schema')
|
||||
self._test_get_unique_constraints(schema=testing.config.test_schema)
|
||||
|
||||
@testing.provide_metadata
|
||||
def _test_get_unique_constraints(self, schema=None):
|
||||
# SQLite dialect needs to parse the names of the constraints
|
||||
# separately from what it gets from PRAGMA index_list(), and
|
||||
# then matches them up. so same set of column_names in two
|
||||
# constraints will confuse it. Perhaps we should no longer
|
||||
# bother with index_list() here since we have the whole
|
||||
# CREATE TABLE?
|
||||
uniques = sorted(
|
||||
[
|
||||
{'name': 'unique_a', 'column_names': ['a']},
|
||||
{'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
|
||||
{'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
|
||||
{'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
|
||||
{'name': 'i.have.dots', 'column_names': ['b']},
|
||||
{'name': 'i have spaces', 'column_names': ['c']},
|
||||
],
|
||||
key=operator.itemgetter('name')
|
||||
)
|
||||
|
|
@ -464,6 +581,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
)
|
||||
|
||||
for orig, refl in zip(uniques, reflected):
|
||||
# Different dialects handle duplicate index and constraints
|
||||
# differently, so ignore this flag
|
||||
refl.pop('duplicates_index', None)
|
||||
eq_(orig, refl)
|
||||
|
||||
@testing.provide_metadata
|
||||
|
|
@ -486,7 +606,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
@testing.requires.view_reflection
|
||||
@testing.requires.schemas
|
||||
def test_get_view_definition_with_schema(self):
|
||||
self._test_get_view_definition(schema='test_schema')
|
||||
self._test_get_view_definition(schema=testing.config.test_schema)
|
||||
|
||||
@testing.only_on("postgresql", "PG specific feature")
|
||||
@testing.provide_metadata
|
||||
|
|
@ -503,7 +623,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
|
|||
|
||||
@testing.requires.schemas
|
||||
def test_get_table_oid_with_schema(self):
|
||||
self._test_get_table_oid('users', schema='test_schema')
|
||||
self._test_get_table_oid('users', schema=testing.config.test_schema)
|
||||
|
||||
@testing.requires.table_reflection
|
||||
@testing.provide_metadata
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from .. import fixtures, config
|
|||
from ..assertions import eq_
|
||||
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy import Integer, String, select, func
|
||||
from sqlalchemy import Integer, String, select, func, bindparam
|
||||
from sqlalchemy import testing
|
||||
|
||||
from ..schema import Table, Column
|
||||
|
||||
|
||||
class OrderByLabelTest(fixtures.TablesTest):
|
||||
|
||||
"""Test the dialect sends appropriate ORDER BY expressions when
|
||||
labels are used.
|
||||
|
||||
|
|
@ -85,3 +85,108 @@ class OrderByLabelTest(fixtures.TablesTest):
|
|||
select([lx]).order_by(lx.desc()),
|
||||
[(7, ), (5, ), (3, )]
|
||||
)
|
||||
|
||||
def test_group_by_composed(self):
|
||||
table = self.tables.some_table
|
||||
expr = (table.c.x + table.c.y).label('lx')
|
||||
stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr)
|
||||
self._assert_result(
|
||||
stmt,
|
||||
[(1, 3), (1, 5), (1, 7)]
|
||||
)
|
||||
|
||||
|
||||
class LimitOffsetTest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table("some_table", metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('x', Integer),
|
||||
Column('y', Integer))
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls):
|
||||
config.db.execute(
|
||||
cls.tables.some_table.insert(),
|
||||
[
|
||||
{"id": 1, "x": 1, "y": 2},
|
||||
{"id": 2, "x": 2, "y": 3},
|
||||
{"id": 3, "x": 3, "y": 4},
|
||||
{"id": 4, "x": 4, "y": 5},
|
||||
]
|
||||
)
|
||||
|
||||
def _assert_result(self, select, result, params=()):
|
||||
eq_(
|
||||
config.db.execute(select, params).fetchall(),
|
||||
result
|
||||
)
|
||||
|
||||
def test_simple_limit(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).limit(2),
|
||||
[(1, 1, 2), (2, 2, 3)]
|
||||
)
|
||||
|
||||
@testing.requires.offset
|
||||
def test_simple_offset(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).offset(2),
|
||||
[(3, 3, 4), (4, 4, 5)]
|
||||
)
|
||||
|
||||
@testing.requires.offset
|
||||
def test_simple_limit_offset(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).limit(2).offset(1),
|
||||
[(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.offset
|
||||
def test_limit_offset_nobinds(self):
|
||||
"""test that 'literal binds' mode works - no bound params."""
|
||||
|
||||
table = self.tables.some_table
|
||||
stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
|
||||
sql = stmt.compile(
|
||||
dialect=config.db.dialect,
|
||||
compile_kwargs={"literal_binds": True})
|
||||
sql = str(sql)
|
||||
|
||||
self._assert_result(
|
||||
sql,
|
||||
[(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.bound_limit_offset
|
||||
def test_bound_limit(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).limit(bindparam('l')),
|
||||
[(1, 1, 2), (2, 2, 3)],
|
||||
params={"l": 2}
|
||||
)
|
||||
|
||||
@testing.requires.bound_limit_offset
|
||||
def test_bound_offset(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).offset(bindparam('o')),
|
||||
[(3, 3, 4), (4, 4, 5)],
|
||||
params={"o": 2}
|
||||
)
|
||||
|
||||
@testing.requires.bound_limit_offset
|
||||
def test_bound_limit_offset(self):
|
||||
table = self.tables.some_table
|
||||
self._assert_result(
|
||||
select([table]).order_by(table.c.id).
|
||||
limit(bindparam("l")).offset(bindparam("o")),
|
||||
[(2, 2, 3), (3, 3, 4)],
|
||||
params={"l": 2, "o": 1}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -86,11 +86,11 @@ class HasSequenceTest(fixtures.TestBase):
|
|||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_schema(self):
|
||||
s1 = Sequence('user_id_seq', schema="test_schema")
|
||||
s1 = Sequence('user_id_seq', schema=config.test_schema)
|
||||
testing.db.execute(schema.CreateSequence(s1))
|
||||
try:
|
||||
eq_(testing.db.dialect.has_sequence(
|
||||
testing.db, 'user_id_seq', schema="test_schema"), True)
|
||||
testing.db, 'user_id_seq', schema=config.test_schema), True)
|
||||
finally:
|
||||
testing.db.execute(schema.DropSequence(s1))
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class HasSequenceTest(fixtures.TestBase):
|
|||
@testing.requires.schemas
|
||||
def test_has_sequence_schemas_neg(self):
|
||||
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
|
||||
schema="test_schema"),
|
||||
schema=config.test_schema),
|
||||
False)
|
||||
|
||||
@testing.requires.schemas
|
||||
|
|
@ -110,14 +110,14 @@ class HasSequenceTest(fixtures.TestBase):
|
|||
testing.db.execute(schema.CreateSequence(s1))
|
||||
try:
|
||||
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
|
||||
schema="test_schema"),
|
||||
schema=config.test_schema),
|
||||
False)
|
||||
finally:
|
||||
testing.db.execute(schema.DropSequence(s1))
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_remote_not_in_default(self):
|
||||
s1 = Sequence('user_id_seq', schema="test_schema")
|
||||
s1 = Sequence('user_id_seq', schema=config.test_schema)
|
||||
testing.db.execute(schema.CreateSequence(s1))
|
||||
try:
|
||||
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/util.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -147,6 +147,10 @@ def run_as_contextmanager(ctx, fn, *arg, **kw):
|
|||
simulating the behavior of 'with' to support older
|
||||
Python versions.
|
||||
|
||||
This is not necessary anymore as we have placed 2.6
|
||||
as minimum Python version, however some tests are still using
|
||||
this structure.
|
||||
|
||||
"""
|
||||
|
||||
obj = ctx.__enter__()
|
||||
|
|
@ -181,6 +185,7 @@ def provide_metadata(fn, *args, **kw):
|
|||
"""Provide bound MetaData for a single test, dropping afterwards."""
|
||||
|
||||
from . import config
|
||||
from . import engines
|
||||
from sqlalchemy import schema
|
||||
|
||||
metadata = schema.MetaData(config.db)
|
||||
|
|
@ -190,10 +195,29 @@ def provide_metadata(fn, *args, **kw):
|
|||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
metadata.drop_all()
|
||||
engines.drop_all_tables(metadata, config.db)
|
||||
self.metadata = prev_meta
|
||||
|
||||
|
||||
def force_drop_names(*names):
|
||||
"""Force the given table names to be dropped after test complete,
|
||||
isolating for foreign key cycles
|
||||
|
||||
"""
|
||||
from . import config
|
||||
from sqlalchemy import inspect
|
||||
|
||||
@decorator
|
||||
def go(fn, *args, **kw):
|
||||
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
drop_all_tables(
|
||||
config.db, inspect(config.db), include_names=names)
|
||||
return go
|
||||
|
||||
|
||||
class adict(dict):
|
||||
"""Dict keys available as attributes. Shadows."""
|
||||
|
||||
|
|
@ -203,5 +227,54 @@ class adict(dict):
|
|||
except KeyError:
|
||||
return dict.__getattribute__(self, key)
|
||||
|
||||
def get_all(self, *keys):
|
||||
def __call__(self, *keys):
|
||||
return tuple([self[key] for key in keys])
|
||||
|
||||
get_all = __call__
|
||||
|
||||
|
||||
def drop_all_tables(engine, inspector, schema=None, include_names=None):
|
||||
from sqlalchemy import Column, Table, Integer, MetaData, \
|
||||
ForeignKeyConstraint
|
||||
from sqlalchemy.schema import DropTable, DropConstraint
|
||||
|
||||
if include_names is not None:
|
||||
include_names = set(include_names)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for tname, fkcs in reversed(
|
||||
inspector.get_sorted_table_and_fkc_names(schema=schema)):
|
||||
if tname:
|
||||
if include_names is not None and tname not in include_names:
|
||||
continue
|
||||
conn.execute(DropTable(
|
||||
Table(tname, MetaData(), schema=schema)
|
||||
))
|
||||
elif fkcs:
|
||||
if not engine.dialect.supports_alter:
|
||||
continue
|
||||
for tname, fkc in fkcs:
|
||||
if include_names is not None and \
|
||||
tname not in include_names:
|
||||
continue
|
||||
tb = Table(
|
||||
tname, MetaData(),
|
||||
Column('x', Integer),
|
||||
Column('y', Integer),
|
||||
schema=schema
|
||||
)
|
||||
conn.execute(DropConstraint(
|
||||
ForeignKeyConstraint(
|
||||
[tb.c.x], [tb.c.y], name=fkc)
|
||||
))
|
||||
|
||||
|
||||
def teardown_events(event_cls):
|
||||
@decorator
|
||||
def decorate(fn, *arg, **kw):
|
||||
try:
|
||||
return fn(*arg, **kw)
|
||||
finally:
|
||||
event_cls._clear()
|
||||
return decorate
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# testing/warnings.py
|
||||
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
|
|
@ -9,25 +9,11 @@ from __future__ import absolute_import
|
|||
|
||||
import warnings
|
||||
from .. import exc as sa_exc
|
||||
from .. import util
|
||||
import re
|
||||
from . import assertions
|
||||
|
||||
|
||||
def testing_warn(msg, stacklevel=3):
|
||||
"""Replaces sqlalchemy.util.warn during tests."""
|
||||
|
||||
filename = "sqlalchemy.testing.warnings"
|
||||
lineno = 1
|
||||
if isinstance(msg, util.string_types):
|
||||
warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
|
||||
else:
|
||||
warnings.warn_explicit(msg, filename, lineno)
|
||||
|
||||
|
||||
def resetwarnings():
|
||||
"""Reset warning behavior to testing defaults."""
|
||||
|
||||
util.warn = util.langhelpers.warn = testing_warn
|
||||
def setup_filters():
|
||||
"""Set global warning behavior for the test suite."""
|
||||
|
||||
warnings.filterwarnings('ignore',
|
||||
category=sa_exc.SAPendingDeprecationWarning)
|
||||
|
|
@ -35,24 +21,14 @@ def resetwarnings():
|
|||
warnings.filterwarnings('error', category=sa_exc.SAWarning)
|
||||
|
||||
|
||||
def assert_warnings(fn, warnings, regex=False):
|
||||
"""Assert that each of the given warnings are emitted by fn."""
|
||||
def assert_warnings(fn, warning_msgs, regex=False):
|
||||
"""Assert that each of the given warnings are emitted by fn.
|
||||
|
||||
from .assertions import eq_, emits_warning
|
||||
Deprecated. Please use assertions.expect_warnings().
|
||||
|
||||
canary = []
|
||||
orig_warn = util.warn
|
||||
"""
|
||||
|
||||
def capture_warnings(*args, **kw):
|
||||
orig_warn(*args, **kw)
|
||||
popwarn = warnings.pop(0)
|
||||
canary.append(popwarn)
|
||||
if regex:
|
||||
assert re.match(popwarn, args[0])
|
||||
else:
|
||||
eq_(args[0], popwarn)
|
||||
util.warn = util.langhelpers.warn = capture_warnings
|
||||
with assertions._expect_warnings(
|
||||
sa_exc.SAWarning, warning_msgs, regex=regex):
|
||||
return fn()
|
||||
|
||||
result = emits_warning()(fn)()
|
||||
assert canary, "No warning was emitted"
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue