update sqlalchemy

This commit is contained in:
j 2016-02-22 16:43:36 +05:30
commit 3b436646a2
362 changed files with 37720 additions and 11021 deletions

View file

@ -1,12 +1,12 @@
# testing/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .warnings import testing_warn, assert_warnings, resetwarnings
from .warnings import assert_warnings
from . import config
@ -19,11 +19,14 @@ def against(*queries):
return _against(config._current, *queries)
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
AssertsExecutionResults, expect_deprecated
AssertsExecutionResults, expect_deprecated, expect_warnings, \
in_, not_in_
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
from .util import run_as_contextmanager, rowset, fail, \
provide_metadata, adict, force_drop_names, \
teardown_events
crashes = skip

View file

@ -1,5 +1,5 @@
# testing/assertions.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -9,79 +9,88 @@ from __future__ import absolute_import
from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, create_engine, url
from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default, url
from sqlalchemy.util import decorator
from sqlalchemy import types as sqltypes, schema
from sqlalchemy import types as sqltypes, schema, exc as sa_exc
import warnings
import re
from .warnings import resetwarnings
from .exclusions import db_spec, _is_excluded
from . import assertsql
from . import config
import itertools
from .util import fail
import contextlib
from . import mock
def expect_warnings(*messages, **kw):
"""Context manager which expects one or more warnings.
With no arguments, squelches all SAWarnings emitted via
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
pass string expressions that will match selected warnings via regex;
all non-matching warnings are sent through.
The expect version **asserts** that the warnings were in fact seen.
Note that the test suite sets SAWarning warnings to raise exceptions.
"""
return _expect_warnings(sa_exc.SAWarning, messages, **kw)
@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
"""Context manager which expects one or more warnings on specific
dialects.
The expect version **asserts** that the warnings were in fact seen.
"""
spec = db_spec(db)
if isinstance(db, util.string_types) and not spec(config._current):
yield
else:
with expect_warnings(*messages, **kw):
yield
def emits_warning(*messages):
"""Mark a test as emitting a warning.
"""Decorator form of expect_warnings().
Note that emits_warning does **not** assert that the warnings
were in fact seen.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
# TODO: it would be nice to assert that a named warning was
# emitted. should work with some monkeypatching of warnings,
# and may work on non-CPython if they keep to the spirit of
# warnings.showwarning's docstring.
# - update: jython looks ok, it uses cpython's module
@decorator
def decorate(fn, *args, **kw):
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
if not messages:
filters.append(dict(action='ignore',
category=sa_exc.SAWarning))
else:
filters.extend(dict(action='ignore',
message=message,
category=sa_exc.SAWarning)
for message in messages)
for f in filters:
warnings.filterwarnings(**f)
try:
with expect_warnings(assert_=False, *messages):
return fn(*args, **kw)
finally:
resetwarnings()
return decorate
def emits_warning_on(db, *warnings):
def expect_deprecated(*messages, **kw):
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
def emits_warning_on(db, *messages):
"""Mark a test as emitting a warning on a specific dialect.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
spec = db_spec(db)
Note that emits_warning_on does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
if isinstance(db, util.string_types):
if not spec(config._current):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
return wrapped(*args, **kw)
else:
if not _is_excluded(*db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
return wrapped(*args, **kw)
with expect_warnings_on(db, assert_=False, *messages):
return fn(*args, **kw)
return decorate
@ -95,39 +104,52 @@ def uses_deprecated(*messages):
As a special case, you may pass a function name prefixed with //
and it will be re-written as needed to match the standard warning
verbiage emitted by the sqlalchemy.util.deprecated decorator.
Note that uses_deprecated does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_deprecated(*messages):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
return decorate
@contextlib.contextmanager
def expect_deprecated(*messages):
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
if not messages:
filters.append(dict(action='ignore',
category=sa_exc.SADeprecationWarning))
else:
filters.extend(
[dict(action='ignore',
message=message,
category=sa_exc.SADeprecationWarning)
for message in
[(m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
for m in messages]])
def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
for f in filters:
warnings.filterwarnings(**f)
try:
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
else:
filters = messages
seen = set(filters)
real_warn = warnings.warn
def our_warn(msg, exception, *arg, **kw):
if not issubclass(exception, exc_cls):
return real_warn(msg, exception, *arg, **kw)
if not filters:
return
for filter_ in filters:
if (regex and filter_.match(msg)) or \
(not regex and filter_ == msg):
seen.discard(filter_)
break
else:
real_warn(msg, exception, *arg, **kw)
with mock.patch("warnings.warn", our_warn):
yield
finally:
resetwarnings()
if assert_:
assert not seen, "Warnings were not seen: %s" % \
", ".join("%r" % (s.pattern if regex else s) for s in seen)
def global_cleanup_assertions():
@ -192,6 +214,11 @@ def ne_(a, b, msg=None):
assert a != b, msg or "%r == %r" % (a, b)
def le_(a, b, msg=None):
"""Assert a <= b, with repr messaging on failure."""
assert a <= b, msg or "%r != %r" % (a, b)
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
@ -202,6 +229,16 @@ def is_not_(a, b, msg=None):
assert a is not b, msg or "%r is %r" % (a, b)
def in_(a, b, msg=None):
"""Assert a in b, with repr messaging on failure."""
assert a in b, msg or "%r not in %r" % (a, b)
def not_in_(a, b, msg=None):
"""Assert a not in b, with repr messaging on failure."""
assert a not in b, msg or "%r is in %r" % (a, b)
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
@ -233,6 +270,7 @@ class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None,
checkparams=None, dialect=None,
checkpositional=None,
check_prefetch=None,
use_default_dialect=False,
allow_dialect_select=False,
literal_binds=False):
@ -293,6 +331,8 @@ class AssertsCompiledSQL(object):
if checkpositional is not None:
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
class ComparesTables(object):
@ -409,29 +449,27 @@ class AssertsExecutionResults(object):
cls.__name__, repr(expected_item)))
return True
def assert_sql_execution(self, db, callable_, *rules):
assertsql.asserter.add_rules(rules)
try:
callable_()
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
def sql_execution_asserter(self, db=None):
if db is None:
from . import db as db
def assert_sql(self, db, callable_, list_, with_sequences=None):
if (with_sequences is not None and
config.db.dialect.supports_sequences):
rules = with_sequences
else:
rules = list_
return assertsql.assert_engine(db)
def assert_sql_execution(self, db, callable_, *rules):
with self.sql_execution_asserter(db) as asserter:
callable_()
asserter.assert_(*rules)
def assert_sql(self, db, callable_, rules):
newrules = []
for rule in rules:
if isinstance(rule, dict):
newrule = assertsql.AllOf(*[
assertsql.ExactSQL(k, v) for k, v in rule.items()
assertsql.CompiledSQL(k, v) for k, v in rule.items()
])
else:
newrule = assertsql.ExactSQL(*rule)
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
self.assert_sql_execution(db, callable_, *newrules)

View file

@ -1,5 +1,5 @@
# testing/assertsql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -8,84 +8,151 @@
from ..engine.default import DefaultDialect
from .. import util
import re
import collections
import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
from sqlalchemy.engine import url
class AssertRule(object):
def process_execute(self, clauseelement, *multiparams, **params):
is_consumed = False
errormessage = None
consume_statement = True
def process_statement(self, execute_observed):
pass
def process_cursor_execute(self, statement, parameters, context,
executemany):
pass
def is_consumed(self):
"""Return True if this rule has been consumed, False if not.
Should raise an AssertionError if this rule's condition has
definitely failed.
"""
raise NotImplementedError()
def rule_passed(self):
"""Return True if the last test of this rule passed, False if
failed, None if no test was applied."""
raise NotImplementedError()
def consume_final(self):
"""Return True if this rule has been consumed.
Should raise an AssertionError if this rule's condition has not
been consumed or has failed.
"""
if self._result is None:
assert False, 'Rule has not been consumed'
return self.is_consumed()
def no_more_statements(self):
assert False, 'All statements are complete, but pending '\
'assertion rules remain'
class SQLMatchRule(AssertRule):
def __init__(self):
self._result = None
self._errmsg = ""
def rule_passed(self):
return self._result
def is_consumed(self):
if self._result is None:
return False
assert self._result, self._errmsg
return True
pass
class ExactSQL(SQLMatchRule):
class CursorSQL(SQLMatchRule):
consume_statement = False
def __init__(self, sql, params=None):
SQLMatchRule.__init__(self)
self.sql = sql
def __init__(self, statement, params=None):
self.statement = statement
self.params = params
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
_received_statement = \
_process_engine_statement(context.unicode_statement,
context)
_received_parameters = context.compiled_parameters
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
self.params is not None and self.params != stmt.parameters):
self.errormessage = \
"Testing for exact SQL %s parameters %s received %s %s" % (
self.statement, self.params,
stmt.statement, stmt.parameters
)
else:
execute_observed.statements.pop(0)
self.is_consumed = True
if not execute_observed.statements:
self.consume_statement = True
# TODO: remove this step once all unit tests are migrated, as
# ExactSQL should really be *exact* SQL
sql = _process_assertion_statement(self.sql, context)
equivalent = _received_statement == sql
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params=None, dialect='default'):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
if self.dialect == 'default':
return DefaultDialect()
else:
# ugh
if self.dialect == 'postgresql':
params = {'implicit_returning': True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
of a target dialect, which for CompiledSQL is just DefaultDialect."""
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
context.compiled.statement.compile(dialect=compare_dialect)
else:
compiled = (
context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline)
)
_received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
compiled.construct_params(m) for m in parameters]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
_received_statement, _received_parameters = \
self._received_statement(execute_observed)
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
if equivalent:
if params is not None:
all_params = list(params)
all_received = list(_received_parameters)
while all_params and all_received:
param = dict(all_params.pop(0))
for idx, received in enumerate(list(all_received)):
# do a positive compare only
for param_key in param:
# a key in param did not match current
# 'received'
if param_key not in received or \
received[param_key] != param[param_key]:
break
else:
# all keys in param matched 'received';
# onto next param
del all_received[idx]
break
else:
# param did not match any entry
# in all_received
equivalent = False
break
if all_params or all_received:
equivalent = False
if equivalent:
self.is_consumed = True
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
'received_statement': _received_statement,
'received_parameters': _received_parameters
}
def _all_params(self, context):
if self.params:
if util.callable(self.params):
params = self.params(context)
@ -93,127 +160,84 @@ class ExactSQL(SQLMatchRule):
params = self.params
if not isinstance(params, list):
params = [params]
equivalent = equivalent and params \
== context.compiled_parameters
return params
else:
params = {}
self._result = equivalent
if not self._result:
self._errmsg = (
'Testing for exact statement %r exact params %r, '
'received %r with params %r' %
(sql, params, _received_statement, _received_parameters))
return None
def _failure_message(self, expected_params):
return (
'Testing for compiled statement %r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
self.statement.replace('%', '%%'), expected_params
)
)
class RegexSQL(SQLMatchRule):
class RegexSQL(CompiledSQL):
def __init__(self, regex, params=None):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
self.dialect = 'default'
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
_received_statement = \
_process_engine_statement(context.unicode_statement,
context)
_received_parameters = context.compiled_parameters
equivalent = bool(self.regex.match(_received_statement))
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
# do a positive compare only
for param, received in zip(params, _received_parameters):
for k, v in param.items():
if k not in received or received[k] != v:
equivalent = False
break
else:
params = {}
self._result = equivalent
if not self._result:
self._errmsg = \
'Testing for regex %r partial params %r, received %r '\
'with params %r' % (self.orig_regex, params,
_received_statement,
_received_parameters)
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params=None):
SQLMatchRule.__init__(self)
self.statement = statement
self.params = params
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
from sqlalchemy.schema import _DDLCompiles
_received_parameters = list(context.compiled_parameters)
# recompile from the context, using the default dialect
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
context.compiled.statement.compile(dialect=DefaultDialect())
else:
compiled = (
context.compiled.statement.compile(
dialect=DefaultDialect(),
column_keys=context.compiled.column_keys)
def _failure_message(self, expected_params):
return (
'Testing for compiled statement ~%r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
self.orig_regex, expected_params
)
_received_statement = re.sub(r'[\n\t]', '', str(compiled))
equivalent = self.statement == _received_statement
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
else:
params = list(params)
all_params = list(params)
all_received = list(_received_parameters)
while params:
param = dict(params.pop(0))
for k, v in context.compiled.params.items():
param.setdefault(k, v)
if param not in _received_parameters:
equivalent = False
break
else:
_received_parameters.remove(param)
if _received_parameters:
equivalent = False
else:
params = {}
all_params = {}
all_received = []
self._result = equivalent
if not self._result:
print('Testing for compiled statement %r partial params '
'%r, received %r with params %r' %
(self.statement, all_params,
_received_statement, all_received))
self._errmsg = (
'Testing for compiled statement %r partial params %r, '
'received %r with params %r' %
(self.statement, all_params,
_received_statement, all_received))
)
# print self._errmsg
def _compare_sql(self, execute_observed, received_statement):
return bool(self.regex.match(received_statement))
class DialectSQL(CompiledSQL):
def _compile_dialect(self, execute_observed):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
stmt = re.sub(r'[\n\t]', '', real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
received_stmt, received_params = super(DialectSQL, self).\
_received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
if self._compare_no_space(real_stmt.statement, received_stmt):
break
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
"statements actually invoked" % received_stmt)
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
if paramstyle == 'pyformat':
stmt = re.sub(
r':([\w_]+)', r"%(\1)s", stmt)
else:
# positional params
repl = None
if paramstyle == 'qmark':
repl = "?"
elif paramstyle == 'format':
repl = r"%s"
elif paramstyle == 'numeric':
repl = None
stmt = re.sub(r':([\w_]+)', repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
@ -222,21 +246,13 @@ class CountStatements(AssertRule):
self.count = count
self._statement_count = 0
def process_execute(self, clauseelement, *multiparams, **params):
def process_statement(self, execute_observed):
self._statement_count += 1
def process_cursor_execute(self, statement, parameters, context,
executemany):
pass
def is_consumed(self):
return False
def consume_final(self):
assert self.count == self._statement_count, \
'desired statement count %d does not match %d' \
% (self.count, self._statement_count)
return True
def no_more_statements(self):
if self.count != self._statement_count:
assert False, 'desired statement count %d does not match %d' \
% (self.count, self._statement_count)
class AllOf(AssertRule):
@ -244,116 +260,113 @@ class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
def process_execute(self, clauseelement, *multiparams, **params):
for rule in self.rules:
rule.process_execute(clauseelement, *multiparams, **params)
def process_cursor_execute(self, statement, parameters, context,
executemany):
for rule in self.rules:
rule.process_cursor_execute(statement, parameters, context,
executemany)
def is_consumed(self):
if not self.rules:
return True
def process_statement(self, execute_observed):
for rule in list(self.rules):
if rule.rule_passed(): # a rule passed, move on
self.rules.remove(rule)
return len(self.rules) == 0
return False
def rule_passed(self):
return self.is_consumed()
def consume_final(self):
return len(self.rules) == 0
rule.errormessage = None
rule.process_statement(execute_observed)
if rule.is_consumed:
self.rules.discard(rule)
if not self.rules:
self.is_consumed = True
break
elif not rule.errormessage:
# rule is not done yet
self.errormessage = None
break
else:
self.errormessage = list(self.rules)[0].errormessage
class Or(AllOf):
def __init__(self, *rules):
self.rules = set(rules)
self._consume_final = False
def is_consumed(self):
if not self.rules:
return True
for rule in list(self.rules):
if rule.rule_passed(): # a rule passed
self._consume_final = True
return True
return False
def consume_final(self):
assert self._consume_final, "Unsatisified rules remain"
def _process_engine_statement(query, context):
if util.jython:
# oracle+zxjdbc passes a PyStatement when returning into
query = str(query)
if context.engine.name == 'mssql' \
and query.endswith('; select scope_identity()'):
query = query[:-25]
query = re.sub(r'\n', '', query)
return query
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle == 'pyformat':
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
else:
# positional params
repl = None
if paramstyle == 'qmark':
repl = "?"
elif paramstyle == 'format':
repl = r"%s"
elif paramstyle == 'numeric':
repl = None
query = re.sub(r':([\w_]+)', repl, query)
return query
class SQLAssert(object):
rules = None
def add_rules(self, rules):
self.rules = list(rules)
def statement_complete(self):
def process_statement(self, execute_observed):
for rule in self.rules:
if not rule.consume_final():
assert False, \
'All statements are complete, but pending '\
'assertion rules remain'
rule.process_statement(execute_observed)
if rule.is_consumed:
self.is_consumed = True
break
else:
self.errormessage = list(self.rules)[0].errormessage
def clear_rules(self):
del self.rules
def execute(self, conn, clauseelement, multiparams, params, result):
if self.rules is not None:
if not self.rules:
assert False, \
'All rules have been exhausted, but further '\
'statements remain'
rule = self.rules[0]
rule.process_execute(clauseelement, *multiparams, **params)
if rule.is_consumed():
self.rules.pop(0)
class SQLExecuteObserved(object):
def __init__(self, context, clauseelement, multiparams, params):
self.context = context
self.clauseelement = clauseelement
self.parameters = _distill_params(multiparams, params)
self.statements = []
def cursor_execute(self, conn, cursor, statement, parameters,
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
["statement", "parameters", "context", "executemany"])
):
pass
class SQLAsserter(object):
def __init__(self):
self.accumulated = []
def _close(self):
self._final = self.accumulated
del self.accumulated
def assert_(self, *rules):
rules = list(rules)
observed = list(self._final)
while observed and rules:
rule = rules[0]
rule.process_statement(observed[0])
if rule.is_consumed:
rules.pop(0)
elif rule.errormessage:
assert False, rule.errormessage
if rule.consume_statement:
observed.pop(0)
if not observed and rules:
rules[0].no_more_statements()
elif not rules and observed:
assert False, "Additional SQL statements remain"
@contextlib.contextmanager
def assert_engine(engine):
asserter = SQLAsserter()
orig = []
@event.listens_for(engine, "before_execute")
def connection_execute(conn, clauseelement, multiparams, params):
# grab the original statement + params before any cursor
# execution
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
def cursor_execute(conn, cursor, statement, parameters,
context, executemany):
if self.rules:
rule = self.rules[0]
rule.process_cursor_execute(statement, parameters, context,
executemany)
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
if asserter.accumulated and \
asserter.accumulated[-1].context is context:
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
statement, parameters, context, executemany)
)
asserter = SQLAssert()
try:
yield asserter
finally:
event.remove(engine, "after_cursor_execute", cursor_execute)
event.remove(engine, "before_execute", connection_execute)
asserter._close()

View file

@ -1,5 +1,5 @@
# testing/config.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -12,8 +12,10 @@ db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
_current = None
_skip_test_exception = None
class Config(object):
@ -22,12 +24,14 @@ class Config(object):
self.db_opts = db_opts
self.options = options
self.file_config = file_config
self.test_schema = "test_schema"
self.test_schema_2 = "test_schema_2"
_stack = collections.deque()
_configs = {}
@classmethod
def register(cls, db, db_opts, options, file_config, namespace):
def register(cls, db, db_opts, options, file_config):
"""add a config as one of the global configs.
If there are no configs set up yet, this config also
@ -35,18 +39,19 @@ class Config(object):
"""
cfg = Config(db, db_opts, options, file_config)
global _current
if not _current:
cls.set_as_current(cfg, namespace)
cls._configs[cfg.db.name] = cfg
cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
cls._configs[cfg.db] = cfg
return cfg
@classmethod
def set_as_current(cls, config, namespace):
global db, _current, db_url
global db, _current, db_url, test_schema, test_schema_2, db_opts
_current = config
db_url = config.db.url
db_opts = config.db_opts
test_schema = config.test_schema
test_schema_2 = config.test_schema_2
namespace.db = db = config.db
@classmethod
@ -78,3 +83,10 @@ class Config(object):
def all_dbs(cls):
for cfg in cls.all_configs():
yield cfg.db
def skip_test(self, msg):
skip_test(msg)
def skip_test(msg):
raise _skip_test_exception(msg)

View file

@ -8,4 +8,4 @@ import pytest
class TestSuite(unittest.TestCase):
def test_sqlalchemy(self):
pytest.main()
pytest.main(["-n", "4", "-q"])

View file

@ -1,5 +1,5 @@
# testing/engines.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -7,15 +7,12 @@
from __future__ import absolute_import
import types
import weakref
from collections import deque
from . import config
from .util import decorator
from .. import event, pool
import re
import warnings
from .. import util
class ConnectionKiller(object):
@ -40,8 +37,6 @@ class ConnectionKiller(object):
def _safe(self, fn):
try:
fn()
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
warnings.warn(
"testing_reaper couldn't "
@ -103,7 +98,14 @@ def drop_all_tables(metadata, bind):
testing_reaper.close_all()
if hasattr(bind, 'close'):
bind.close()
metadata.drop_all(bind)
if not config.db.dialect.supports_alter:
from . import assertions
with assertions.expect_warnings(
"Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@decorator
@ -171,8 +173,6 @@ class ReconnectFixture(object):
def _safe(self, fn):
try:
fn()
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
warnings.warn(
"ReconnectFixture couldn't "
@ -211,7 +211,7 @@ def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
from .assertsql import asserter
from sqlalchemy.engine.url import make_url
if not options:
use_reaper = True
@ -219,15 +219,20 @@ def testing_engine(url=None, options=None):
use_reaper = options.pop('use_reaper', True)
url = url or config.db.url
url = make_url(url)
if options is None:
options = config.db_opts
if config.db is None or url.drivername == config.db.url.drivername:
options = config.db_opts
else:
options = {}
engine = create_engine(url, **options)
engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
event.listen(engine, 'after_execute', asserter.execute)
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
if use_reaper:
event.listen(engine.pool, 'connect', testing_reaper.connect)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
@ -287,10 +292,10 @@ class DBAPIProxyCursor(object):
"""
def __init__(self, engine, conn):
def __init__(self, engine, conn, *args, **kwargs):
self.engine = engine
self.connection = conn
self.cursor = conn.cursor()
self.cursor = conn.cursor(*args, **kwargs)
def execute(self, stmt, parameters=None, **kw):
if parameters:
@ -318,8 +323,8 @@ class DBAPIProxyConnection(object):
self.engine = engine
self.cursor_cls = cursor_cls
def cursor(self):
return self.cursor_cls(self.engine, self.conn)
def cursor(self, *args, **kwargs):
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
def close(self):
self.conn.close()
@ -339,112 +344,3 @@ def proxying_engine(conn_cls=DBAPIProxyConnection,
return testing_engine(options={'creator': mock_conn})
class ReplayableSession(object):
"""A simple record/playback tool.
This is *not* a mock testing class. It only records a session for later
playback and makes no assertions on call consistency whatsoever. It's
unlikely to be suitable for anything other than DB-API recording.
"""
Callable = object()
NoAttribute = object()
if util.py2k:
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]).\
difference([getattr(types, t)
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', 'UnboundMethodType',)])
else:
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]).\
union([type(t) if not isinstance(t, type)
else t for t in __builtins__.values()]).\
difference([getattr(types, t)
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', )])
def __init__(self):
self.buffer = deque()
def recorder(self, base):
return self.Recorder(self.buffer, base)
def player(self):
return self.Player(self.buffer)
class Recorder(object):
def __init__(self, buffer, subject):
self._buffer = buffer
self._subject = subject
def __call__(self, *args, **kw):
subject, buffer = [object.__getattribute__(self, x)
for x in ('_subject', '_buffer')]
result = subject(*args, **kw)
if type(result) not in ReplayableSession.Natives:
buffer.append(ReplayableSession.Callable)
return type(self)(buffer, result)
else:
buffer.append(result)
return result
@property
def _sqla_unwrap(self):
return self._subject
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
pass
subject, buffer = [object.__getattribute__(self, x)
for x in ('_subject', '_buffer')]
try:
result = type(subject).__getattribute__(subject, key)
except AttributeError:
buffer.append(ReplayableSession.NoAttribute)
raise
else:
if type(result) not in ReplayableSession.Natives:
buffer.append(ReplayableSession.Callable)
return type(self)(buffer, result)
else:
buffer.append(result)
return result
class Player(object):
def __init__(self, buffer):
self._buffer = buffer
def __call__(self, *args, **kw):
buffer = object.__getattribute__(self, '_buffer')
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
else:
return result
@property
def _sqla_unwrap(self):
return None
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
pass
buffer = object.__getattribute__(self, '_buffer')
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
elif result is ReplayableSession.NoAttribute:
raise AttributeError(key)
else:
return result

View file

@ -1,5 +1,5 @@
# testing/entities.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View file

@ -1,5 +1,5 @@
# testing/exclusions.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -7,87 +7,161 @@
import operator
from .plugin.plugin_base import SkipTest
from ..util import decorator
from . import config
from .. import util
import contextlib
import inspect
import contextlib
from sqlalchemy.util.compat import inspect_getargspec
class skip_if(object):
def __init__(self, predicate, reason=None):
self.predicate = _as_predicate(predicate)
self.reason = reason
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.skips.add(pred)
return rule
_fails_on = None
def fails_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.fails.add(pred)
return rule
class compound(object):
def __init__(self):
self.fails = set()
self.skips = set()
self.tags = set()
def __add__(self, other):
def decorate(fn):
return other(self(fn))
return decorate
return self.add(other)
def add(self, *others):
copy = compound()
copy.fails.update(self.fails)
copy.skips.update(self.skips)
copy.tags.update(self.tags)
for other in others:
copy.fails.update(other.fails)
copy.skips.update(other.skips)
copy.tags.update(other.tags)
return copy
def not_(self):
copy = compound()
copy.fails.update(NotPredicate(fail) for fail in self.fails)
copy.skips.update(NotPredicate(skip) for skip in self.skips)
copy.tags.update(self.tags)
return copy
@property
def enabled(self):
return self.enabled_for_config(config._current)
def enabled_for_config(self, config):
return not self.predicate(config)
for predicate in self.skips.union(self.fails):
if predicate(config):
return False
else:
return True
def matching_config_reasons(self, config):
return [
predicate._as_string(config) for predicate
in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
not self.tags.intersection(exclude_tags) and
(not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
self.skips.update(other.skips)
self.fails.update(other.fails)
self.tags.update(other.tags)
def __call__(self, fn):
if hasattr(fn, '_sa_exclusion_extend'):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@contextlib.contextmanager
def fail_if(self, name='block'):
def fail_if(self):
all_fails = compound()
all_fails.fails.update(self.skips.union(self.fails))
try:
yield
except Exception as ex:
if self.predicate(config._current):
print(("%s failed as expected (%s): %s " % (
name, self.predicate, str(ex))))
else:
raise
all_fails._expect_failure(config._current, ex)
else:
if self.predicate(config._current):
raise AssertionError(
"Unexpected success for '%s' (%s)" %
(name, self.predicate))
all_fails._expect_success(config._current)
def __call__(self, fn):
@decorator
def decorate(fn, *args, **kw):
if self.predicate(config._current):
if self.reason:
msg = "'%s' : %s" % (
fn.__name__,
self.reason
def _do(self, config, fn, *args, **kw):
for skip in self.skips:
if skip(config):
msg = "'%s' : %s" % (
fn.__name__,
skip._as_string(config)
)
config.skip_test(msg)
try:
return_value = fn(*args, **kw)
except Exception as ex:
self._expect_failure(config, ex, name=fn.__name__)
else:
self._expect_success(config, name=fn.__name__)
return return_value
def _expect_failure(self, config, ex, name='block'):
for fail in self.fails:
if fail(config):
print(("%s failed as expected (%s): %s " % (
name, fail._as_string(config), str(ex))))
break
else:
util.raise_from_cause(ex)
def _expect_success(self, config, name='block'):
if not self.fails:
return
for fail in self.fails:
if not fail(config):
break
else:
raise AssertionError(
"Unexpected success for '%s' (%s)" %
(
name,
" and ".join(
fail._as_string(config)
for fail in self.fails
)
else:
msg = "'%s': %s" % (
fn.__name__, self.predicate
)
raise SkipTest(msg)
else:
if self._fails_on:
with self._fails_on.fail_if(name=fn.__name__):
return fn(*args, **kw)
else:
return fn(*args, **kw)
return decorate(fn)
def fails_on(self, other, reason=None):
self._fails_on = skip_if(other, reason)
return self
def fails_on_everything_except(self, *dbs):
self._fails_on = skip_if(fails_on_everything_except(*dbs))
return self
)
)
class fails_if(skip_if):
def __call__(self, fn):
@decorator
def decorate(fn, *args, **kw):
with self.fail_if(name=fn.__name__):
return fn(*args, **kw)
return decorate(fn)
def requires_tag(tagname):
return tags([tagname])
def tags(tagnames):
comp = compound()
comp.tags.update(tagnames)
return comp
def only_if(predicate, reason=None):
@ -102,13 +176,17 @@ def succeeds_if(predicate, reason=None):
class Predicate(object):
@classmethod
def as_predicate(cls, predicate):
if isinstance(predicate, skip_if):
return NotPredicate(predicate.predicate)
def as_predicate(cls, predicate, description=None):
if isinstance(predicate, compound):
return cls.as_predicate(predicate.enabled_for_config, description)
elif isinstance(predicate, Predicate):
if description and predicate.description is None:
predicate.description = description
return predicate
elif isinstance(predicate, list):
return OrPredicate([cls.as_predicate(pred) for pred in predicate])
elif isinstance(predicate, (list, set)):
return OrPredicate(
[cls.as_predicate(pred) for pred in predicate],
description)
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, util.string_types):
@ -119,12 +197,26 @@ class Predicate(object):
op = tokens.pop(0)
if tokens:
spec = tuple(int(d) for d in tokens.pop(0).split("."))
return SpecPredicate(db, op, spec)
return SpecPredicate(db, op, spec, description=description)
elif util.callable(predicate):
return LambdaPredicate(predicate)
return LambdaPredicate(predicate, description)
else:
assert False, "unknown predicate type: %s" % predicate
def _format_description(self, config, negate=False):
bool_ = self(config)
if negate:
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name(),
"database": config.db.url.get_backend_name(),
"doesnt_support": "doesn't support" if bool_ else "does support",
"does_support": "does support" if bool_ else "doesn't support"
}
def _as_string(self, config=None, negate=False):
raise NotImplementedError()
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
@ -134,14 +226,8 @@ class BooleanPredicate(Predicate):
def __call__(self, config):
return self.value
def _as_string(self, negate=False):
if negate:
return "not " + self.description
else:
return self.description
def __str__(self):
return self._as_string()
def _as_string(self, config, negate=False):
return self._format_description(config, negate=negate)
class SpecPredicate(Predicate):
@ -185,9 +271,9 @@ class SpecPredicate(Predicate):
else:
return True
def _as_string(self, negate=False):
def _as_string(self, config, negate=False):
if self.description is not None:
return self.description
return self._format_description(config)
elif self.op is None:
if negate:
return "not %s" % self.db
@ -207,13 +293,10 @@ class SpecPredicate(Predicate):
self.spec
)
def __str__(self):
return self._as_string()
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
spec = inspect.getargspec(lambda_)
spec = inspect_getargspec(lambda_)
if not spec[0]:
self.lambda_ = lambda db: lambda_()
else:
@ -230,25 +313,23 @@ class LambdaPredicate(Predicate):
def __call__(self, config):
return self.lambda_(config)
def _as_string(self, negate=False):
if negate:
return "not " + self.description
else:
return self.description
def __str__(self):
return self._as_string()
def _as_string(self, config, negate=False):
return self._format_description(config)
class NotPredicate(Predicate):
def __init__(self, predicate):
def __init__(self, predicate, description=None):
self.predicate = predicate
self.description = description
def __call__(self, config):
return not self.predicate(config)
def __str__(self):
return self.predicate._as_string(True)
def _as_string(self, config, negate=False):
if self.description:
return self._format_description(config, not negate)
else:
return self.predicate._as_string(config, not negate)
class OrPredicate(Predicate):
@ -259,40 +340,32 @@ class OrPredicate(Predicate):
def __call__(self, config):
for pred in self.predicates:
if pred(config):
self._str = pred
return True
return False
_str = None
def _eval_str(self, negate=False):
if self._str is None:
if negate:
conjunction = " and "
else:
conjunction = " or "
return conjunction.join(p._as_string(negate=negate)
for p in self.predicates)
else:
return self._str._as_string(negate=negate)
def _negation_str(self):
if self.description is not None:
return "Not " + (self.description % {"spec": self._str})
else:
return self._eval_str(negate=True)
def _as_string(self, negate=False):
def _eval_str(self, config, negate=False):
if negate:
return self._negation_str()
conjunction = " and "
else:
conjunction = " or "
return conjunction.join(p._as_string(config, negate=negate)
for p in self.predicates)
def _negation_str(self, config):
if self.description is not None:
return "Not " + self._format_description(config)
else:
return self._eval_str(config, negate=True)
def _as_string(self, config, negate=False):
if negate:
return self._negation_str(config)
else:
if self.description is not None:
return self.description % {"spec": self._str}
return self._format_description(config)
else:
return self._eval_str()
return self._eval_str(config)
def __str__(self):
return self._as_string()
_as_predicate = Predicate.as_predicate
@ -325,8 +398,8 @@ def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
def fails():
return fails_if(BooleanPredicate(True, "expected to fail"))
def fails(reason=None):
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
@decorator
@ -341,8 +414,8 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
return succeeds_if(
OrPredicate([
SpecPredicate(db) for db in dbs
])
SpecPredicate(db) for db in dbs
])
)
@ -352,7 +425,7 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)])
)

View file

@ -1,5 +1,5 @@
# testing/fixtures.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -91,20 +91,12 @@ class TablesTest(TestBase):
cls.run_create_tables = 'each'
assert cls.run_inserts in ('each', None)
if cls.other is None:
cls.other = adict()
cls.other = adict()
cls.tables = adict()
if cls.tables is None:
cls.tables = adict()
if cls.bind is None:
setattr(cls, 'bind', cls.setup_bind())
if cls.metadata is None:
setattr(cls, 'metadata', sa.MetaData())
if cls.metadata.bind is None:
cls.metadata.bind = cls.bind
cls.bind = cls.setup_bind()
cls.metadata = sa.MetaData()
cls.metadata.bind = cls.bind
@classmethod
def _setup_once_inserts(cls):
@ -142,13 +134,14 @@ class TablesTest(TestBase):
def _teardown_each_tables(self):
# no need to run deletes if tables are recreated on setup
if self.run_define_tables != 'each' and self.run_deletes == 'each':
for table in reversed(self.metadata.sorted_tables):
try:
table.delete().execute().close()
except sa.exc.DBAPIError as ex:
util.print_(
("Error emptying table %s: %r" % (table, ex)),
file=sys.stderr)
with self.bind.connect() as conn:
for table in reversed(self.metadata.sorted_tables):
try:
conn.execute(table.delete())
except sa.exc.DBAPIError as ex:
util.print_(
("Error emptying table %s: %r" % (table, ex)),
file=sys.stderr)
def setup(self):
self._setup_each_tables()
@ -200,9 +193,8 @@ class TablesTest(TestBase):
def sql_count_(self, count, fn):
self.assert_sql_count(self.bind, fn, count)
def sql_eq_(self, callable_, statements, with_sequences=None):
self.assert_sql(self.bind,
callable_, statements, with_sequences)
def sql_eq_(self, callable_, statements):
self.assert_sql(self.bind, callable_, statements)
@classmethod
def _load_fixtures(cls):
@ -283,12 +275,14 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
def setup(self):
self._setup_each_tables()
self._setup_each_classes()
self._setup_each_mappers()
self._setup_each_inserts()
def teardown(self):
sa.orm.session.Session.close_all()
self._teardown_each_mappers()
self._teardown_each_classes()
self._teardown_each_tables()
@classmethod
@ -310,6 +304,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
if self.run_setup_mappers == 'each':
self._with_register_classes(self.setup_mappers)
def _setup_each_classes(self):
if self.run_setup_classes == 'each':
self._with_register_classes(self.setup_classes)
@classmethod
def _with_register_classes(cls, fn):
"""Run a setup method, framing the operation with a Base class
@ -344,6 +342,10 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
if self.run_setup_mappers != 'once':
sa.orm.clear_mappers()
def _teardown_each_classes(self):
if self.run_setup_classes != 'once':
self.classes.clear()
@classmethod
def setup_classes(cls):
pass

View file

@ -1,5 +1,5 @@
# testing/mock.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -11,10 +11,10 @@ from __future__ import absolute_import
from ..util import py33
if py33:
from unittest.mock import MagicMock, Mock, call, patch
from unittest.mock import MagicMock, Mock, call, patch, ANY
else:
try:
from mock import MagicMock, Mock, call, patch
from mock import MagicMock, Mock, call, patch, ANY
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "

View file

@ -1,5 +1,5 @@
# testing/pickleable.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View file

@ -0,0 +1,44 @@
"""
Bootstrapper for nose/pytest plugins.
The entire rationale for this system is to get the modules in plugin/
imported without importing all of the supporting library, so that we can
set up things for testing before coverage starts.
The rationale for all of plugin/ being *in* the supporting library in the
first place is so that the testing and plugin suite is available to other
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
of the same test environment and standard suites available to
SQLAlchemy/Alembic themselves without the need to ship/install a separate
package outside of SQLAlchemy.
NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
this should be removable when Alembic targets SQLAlchemy 1.0.0.
"""
import os
import sys
bootstrap_file = locals()['bootstrap_file']
to_bootstrap = locals()['to_bootstrap']
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info >= (3, 3):
from importlib import machinery
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
mod = imp.load_source(name, path)
return mod
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
elif to_bootstrap == "nose":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_noseplugin"] = load_file_as_module("noseplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa

View file

@ -1,5 +1,5 @@
# plugin/noseplugin.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -12,22 +12,22 @@ way (e.g. as a package-less import).
"""
try:
# installed by bootstrap.py
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
import os
import sys
from nose.plugins import Plugin
import nose
fixtures = None
# no package imports yet! this prevents us from tripping coverage
# too soon.
path = os.path.join(os.path.dirname(__file__), "plugin_base.py")
if sys.version_info >= (3, 3):
from importlib import machinery
plugin_base = machinery.SourceFileLoader(
"plugin_base", path).load_module()
else:
import imp
plugin_base = imp.load_source("plugin_base", path)
py3k = sys.version_info >= (3, 0)
class NoseSQLAlchemy(Plugin):
@ -57,28 +57,39 @@ class NoseSQLAlchemy(Plugin):
plugin_base.set_coverage_flag(options.enable_plugin_coverage)
global fixtures
from sqlalchemy.testing import fixtures
plugin_base.set_skip_test(nose.SkipTest)
def begin(self):
global fixtures
from sqlalchemy.testing import fixtures # noqa
plugin_base.post_begin()
def describeTest(self, test):
return ""
def wantFunction(self, fn):
if fn.__module__ is None:
return False
if fn.__module__.startswith('sqlalchemy.testing'):
return False
return False
def wantMethod(self, fn):
if py3k:
if not hasattr(fn.__self__, 'cls'):
return False
cls = fn.__self__.cls
else:
cls = fn.im_class
return plugin_base.want_method(cls, fn)
def wantClass(self, cls):
return plugin_base.want_class(cls)
def beforeTest(self, test):
plugin_base.before_test(test,
test.test.cls.__module__,
test.test.cls, test.test.method.__name__)
if not hasattr(test.test, 'cls'):
return
plugin_base.before_test(
test,
test.test.cls.__module__,
test.test.cls, test.test.method.__name__)
def afterTest(self, test):
plugin_base.after_test(test)

View file

@ -1,5 +1,5 @@
# plugin/plugin_base.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -14,12 +14,6 @@ functionality via py.test.
"""
from __future__ import absolute_import
try:
# unitttest has a SkipTest also but pytest doesn't
# honor it unless nose is imported too...
from nose import SkipTest
except ImportError:
from _pytest.runner import Skipped as SkipTest
import sys
import re
@ -31,7 +25,6 @@ if py3k:
else:
import ConfigParser as configparser
# late imports
fixtures = None
engines = None
@ -47,7 +40,8 @@ file_config = None
logging = None
db_opts = {}
include_tags = set()
exclude_tags = set()
options = None
@ -69,8 +63,6 @@ def setup_options(make_option):
help="Drop all tables in the target database first")
make_option("--backend-only", action="store_true", dest="backend_only",
help="Run only tests marked with __backend__")
make_option("--mockpool", action="store_true", dest="mockpool",
help="Use mock pool (asserts only one connection used)")
make_option("--low-connections", action="store_true",
dest="low_connections",
help="Use a low number of distinct connections - "
@ -86,18 +78,56 @@ def setup_options(make_option):
dest="cdecimal", default=False,
help="Monkeypatch the cdecimal library into Python 'decimal' "
"for all tests")
make_option("--serverside", action="callback",
callback=_server_side_cursors,
help="Turn on server side cursors for PG")
make_option("--mysql-engine", action="store",
dest="mysql_engine", default=None,
help="Use the specified MySQL storage engine for all tables, "
"default is a db-default/InnoDB combo.")
make_option("--tableopts", action="append", dest="tableopts", default=[],
help="Add a dialect-specific table option, key=value")
make_option("--include-tag", action="callback", callback=_include_tag,
type="string",
help="Include tests with tag <tag>")
make_option("--exclude-tag", action="callback", callback=_exclude_tag,
type="string",
help="Exclude tests with tag <tag>")
make_option("--write-profiles", action="store_true",
dest="write_profiles", default=False,
help="Write/update profiling data.")
help="Write/update failing profiling data.")
make_option("--force-write-profiles", action="store_true",
dest="force_write_profiles", default=False,
help="Unconditionally write/update profiling data.")
def configure_follower(follower_ident):
"""Configure required state for a follower.
This invokes in the parent process and typically includes
database creation.
"""
from sqlalchemy.testing import provision
provision.FOLLOWER_IDENT = follower_ident
def memoize_important_follower_config(dict_):
"""Store important configuration we will need to send to a follower.
This invokes in the parent process after normal config is set up.
This is necessary as py.test seems to not be using forking, so we
start with nothing in memory, *but* it isn't running our argparse
callables, so we have to just copy all of that over.
"""
dict_['memoized_config'] = {
'include_tags': include_tags,
'exclude_tags': exclude_tags
}
def restore_important_follower_config(dict_):
"""Restore important configuration needed by a follower.
This invokes in the follower process.
"""
global include_tags, exclude_tags
include_tags.update(dict_['memoized_config']['include_tags'])
exclude_tags.update(dict_['memoized_config']['exclude_tags'])
def read_config():
@ -117,6 +147,13 @@ def pre_begin(opt):
def set_coverage_flag(value):
options.has_coverage = value
_skip_test_exception = None
def set_skip_test(exc):
global _skip_test_exception
_skip_test_exception = exc
def post_begin():
"""things to set up later, once we know coverage is running."""
@ -129,10 +166,12 @@ def post_begin():
global util, fixtures, engines, exclusions, \
assertions, warnings, profiling,\
config, testing
from sqlalchemy import testing
from sqlalchemy.testing import fixtures, engines, exclusions, \
assertions, warnings, profiling, config
from sqlalchemy import util
from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config # noqa
from sqlalchemy import util # noqa
warnings.setup_filters()
def _log(opt_str, value, parser):
@ -154,14 +193,17 @@ def _list_dbs(*args):
sys.exit(0)
def _server_side_cursors(opt_str, value, parser):
db_opts['server_side_cursors'] = True
def _requirements_opt(opt_str, value, parser):
_setup_requirements(value)
def _exclude_tag(opt_str, value, parser):
exclude_tags.add(value.replace('-', '_'))
def _include_tag(opt_str, value, parser):
include_tags.add(value.replace('-', '_'))
pre_configure = []
post_configure = []
@ -189,10 +231,18 @@ def _monkeypatch_cdecimal(options, file_config):
sys.modules['decimal'] = cdecimal
@post
def _init_skiptest(options, file_config):
from sqlalchemy.testing import config
config._skip_test_exception = _skip_test_exception
@post
def _engine_uri(options, file_config):
from sqlalchemy.testing import engines, config
from sqlalchemy.testing import config
from sqlalchemy import testing
from sqlalchemy.testing import provision
if options.dburi:
db_urls = list(options.dburi)
@ -214,18 +264,11 @@ def _engine_uri(options, file_config):
db_urls.append(file_config.get('db', 'default'))
for db_url in db_urls:
eng = engines.testing_engine(db_url, db_opts)
eng.connect().close()
config.Config.register(eng, db_opts, options, file_config, testing)
cfg = provision.setup_config(
db_url, options, file_config, provision.FOLLOWER_IDENT)
config.db_opts = db_opts
@post
def _engine_pool(options, file_config):
if options.mockpool:
from sqlalchemy import pool
db_opts['poolclass'] = pool.AssertionPool
if not config._current:
cfg.set_as_current(cfg, testing)
@post
@ -256,7 +299,8 @@ def _setup_requirements(argument):
@post
def _prep_testing_database(options, file_config):
from sqlalchemy.testing import config
from sqlalchemy.testing import config, util
from sqlalchemy.testing.exclusions import against
from sqlalchemy import schema, inspect
if options.dropfirst:
@ -286,32 +330,18 @@ def _prep_testing_database(options, file_config):
schema="test_schema")
))
for tname in reversed(inspector.get_table_names(
order_by="foreign_key")):
e.execute(schema.DropTable(
schema.Table(tname, schema.MetaData())
))
util.drop_all_tables(e, inspector)
if config.requirements.schemas.enabled_for_config(cfg):
for tname in reversed(inspector.get_table_names(
order_by="foreign_key", schema="test_schema")):
e.execute(schema.DropTable(
schema.Table(tname, schema.MetaData(),
schema="test_schema")
))
util.drop_all_tables(e, inspector, schema=cfg.test_schema)
@post
def _set_table_options(options, file_config):
from sqlalchemy.testing import schema
table_options = schema.table_options
for spec in options.tableopts:
key, value = spec.split('=')
table_options[key] = value
if options.mysql_engine:
table_options['mysql_engine'] = options.mysql_engine
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
for enum in inspector.get_enums("*"):
e.execute(postgresql.DropEnumType(
postgresql.ENUM(
name=enum['name'],
schema=enum['schema'])))
@post
@ -347,6 +377,30 @@ def want_class(cls):
return True
def want_method(cls, fn):
if not fn.__name__.startswith("test_"):
return False
elif fn.__module__ is None:
return False
elif include_tags:
return (
hasattr(cls, '__tags__') and
exclusions.tags(cls.__tags__).include_test(
include_tags, exclude_tags)
) or (
hasattr(fn, '_sa_exclusion_extend') and
fn._sa_exclusion_extend.include_test(
include_tags, exclude_tags)
)
elif exclude_tags and hasattr(cls, '__tags__'):
return exclusions.tags(cls.__tags__).include_test(
include_tags, exclude_tags)
elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
if getattr(cls, '__backend__', False):
for cfg in _possible_configs_for_cls(cls):
@ -356,7 +410,7 @@ def generate_sub_tests(cls, module):
(cls, ),
{
"__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
"__backend__": False}
}
)
setattr(module, name, subcls)
yield subcls
@ -370,6 +424,8 @@ def start_test_class(cls):
def stop_test_class(cls):
#from sqlalchemy import inspect
#assert not inspect(testing.db).get_table_names()
engines.testing_reaper._stop_test_ctx()
if not options.low_connections:
assertions.global_cleanup_assertions()
@ -398,33 +454,27 @@ def before_test(test, test_module_name, test_class, test_name):
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
warnings.resetwarnings()
profiling._current_test = id_
def after_test(test):
engines.testing_reaper._after_test_ctx()
warnings.resetwarnings()
def _possible_configs_for_cls(cls):
def _possible_configs_for_cls(cls, reasons=None):
all_configs = set(config.Config.all_configs())
if cls.__unsupported_on__:
spec = exclusions.db_spec(*cls.__unsupported_on__)
for config_obj in list(all_configs):
if spec(config_obj):
all_configs.remove(config_obj)
if getattr(cls, '__only_on__', None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
return all_configs
def _do_skips(cls):
all_configs = _possible_configs_for_cls(cls)
reasons = []
if hasattr(cls, '__requires__'):
requirements = config.requirements
@ -432,10 +482,11 @@ def _do_skips(cls):
for requirement in cls.__requires__:
check = getattr(requirements, requirement)
if check.predicate(config_obj):
skip_reasons = check.matching_config_reasons(config_obj)
if skip_reasons:
all_configs.remove(config_obj)
if check.reason:
reasons.append(check.reason)
if reasons is not None:
reasons.extend(skip_reasons)
break
if hasattr(cls, '__prefer_requires__'):
@ -445,36 +496,45 @@ def _do_skips(cls):
for requirement in cls.__prefer_requires__:
check = getattr(requirements, requirement)
if check.predicate(config_obj):
if not check.enabled_for_config(config_obj):
non_preferred.add(config_obj)
if all_configs.difference(non_preferred):
all_configs.difference_update(non_preferred)
return all_configs
def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
if getattr(cls, '__skip_if__', False):
for c in getattr(cls, '__skip_if__'):
if c():
raise SkipTest("'%s' skipped by %s" % (
config.skip_test("'%s' skipped by %s" % (
cls.__name__, c.__name__)
)
for db_spec, op, spec in getattr(cls, '__excluded_on__', ()):
for config_obj in list(all_configs):
if exclusions.skip_if(
exclusions.SpecPredicate(db_spec, op, spec)
).predicate(config_obj):
all_configs.remove(config_obj)
if not all_configs:
raise SkipTest(
"'%s' unsupported on DB implementation %s%s" % (
if getattr(cls, '__backend__', False):
msg = "'%s' unsupported for implementation '%s'" % (
cls.__name__, cls.__only_on__)
else:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join("'%s' = %s"
% (config_obj.db.name,
config_obj.db.dialect.server_version_info)
for config_obj in config.Config.all_configs()
),
", ".join(
"'%s(%s)+%s'" % (
config_obj.db.name,
".".join(
str(dig) for dig in
config_obj.db.dialect.server_version_info),
config_obj.db.driver
)
for config_obj in config.Config.all_configs()
),
", ".join(reasons)
)
)
config.skip_test(msg)
elif hasattr(cls, '__prefer_backends__'):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))

View file

@ -1,8 +1,21 @@
try:
# installed by bootstrap.py
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
import pytest
import argparse
import inspect
from . import plugin_base
import collections
import itertools
try:
import xdist # noqa
has_xdist = True
except ImportError:
has_xdist = False
def pytest_addoption(parser):
@ -24,13 +37,40 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
plugin_base.configure_follower(
config.slaveinput["follower_ident"]
)
plugin_base.pre_begin(config.option)
plugin_base.set_coverage_flag(bool(getattr(config.option,
"cov_source", False)))
plugin_base.set_skip_test(pytest.skip.Exception)
def pytest_sessionstart(session):
plugin_base.post_begin()
if has_xdist:
import uuid
def pytest_configure_node(node):
# the master for each node fills slaveinput dictionary
# which pytest-xdist will transfer to the subprocess
plugin_base.memoize_important_follower_config(node.slaveinput)
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from sqlalchemy.testing import provision
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from sqlalchemy.testing import provision
provision.drop_follower_db(node.slaveinput["follower_ident"])
def pytest_collection_modifyitems(session, config, items):
# look for all those classes that specify __backend__ and
@ -44,6 +84,10 @@ def pytest_collection_modifyitems(session, config, items):
# new classes to a module on the fly.
rebuilt_items = collections.defaultdict(list)
items[:] = [
item for item in
items if isinstance(item.parent, pytest.Instance)
and not item.parent.parent.name.startswith("_")]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
@ -74,12 +118,11 @@ def pytest_collection_modifyitems(session, config, items):
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
elif inspect.isfunction(obj) and \
name.startswith("test_") and \
isinstance(collector, pytest.Instance):
isinstance(collector, pytest.Instance) and \
plugin_base.want_method(collector.cls, obj):
return pytest.Function(name, parent=collector)
else:
return []
@ -97,16 +140,18 @@ def pytest_runtest_setup(item):
return
# ... so we're doing a little dance here to figure it out...
if item.parent.parent is not _current_class:
if _current_class is None:
class_setup(item.parent.parent)
_current_class = item.parent.parent
# this is needed for the class-level, to ensure that the
# teardown runs after the class is completed with its own
# class-level teardown...
item.parent.parent.addfinalizer(
lambda: class_teardown(item.parent.parent))
def finalize():
global _current_class
class_teardown(item.parent.parent)
_current_class = None
item.parent.parent.addfinalizer(finalize)
test_setup(item)

View file

@ -1,5 +1,5 @@
# testing/profiling.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -14,13 +14,11 @@ in a more fine-grained way than nose's profiling plugin.
import os
import sys
from .util import gc_collect, decorator
from .util import gc_collect
from . import config
from .plugin.plugin_base import SkipTest
import pstats
import time
import collections
from .. import util
import contextlib
try:
import cProfile
@ -30,64 +28,8 @@ from ..util import jython, pypy, win32, update_wrapper
_current_test = None
def profiled(target=None, **target_opts):
"""Function profiling.
@profiled()
or
@profiled(report=True, sort=('calls',), limit=20)
Outputs profiling info for a decorated function.
"""
profile_config = {'targets': set(),
'report': True,
'print_callers': False,
'print_callees': False,
'graphic': False,
'sort': ('time', 'calls'),
'limit': None}
if target is None:
target = 'anonymous_target'
@decorator
def decorate(fn, *args, **kw):
elapsed, load_stats, result = _profile(
fn, *args, **kw)
graphic = target_opts.get('graphic', profile_config['graphic'])
if graphic:
os.system("runsnake %s" % filename)
else:
report = target_opts.get('report', profile_config['report'])
if report:
sort_ = target_opts.get('sort', profile_config['sort'])
limit = target_opts.get('limit', profile_config['limit'])
print(("Profile report for target '%s'" % (
target, )
))
stats = load_stats()
stats.sort_stats(*sort_)
if limit:
stats.print_stats(limit)
else:
stats.print_stats()
print_callers = target_opts.get(
'print_callers', profile_config['print_callers'])
if print_callers:
stats.print_callers()
print_callees = target_opts.get(
'print_callees', profile_config['print_callees'])
if print_callees:
stats.print_callees()
return result
return decorate
# ProfileStatsFile instance, set up in plugin_base
_profile_stats = None
class ProfileStatsFile(object):
@ -99,7 +41,11 @@ class ProfileStatsFile(object):
"""
def __init__(self, filename):
self.write = (
self.force_write = (
config.options is not None and
config.options.force_write_profiles
)
self.write = self.force_write or (
config.options is not None and
config.options.write_profiles
)
@ -129,6 +75,11 @@ class ProfileStatsFile(object):
platform_tokens.append("pypy")
if win32:
platform_tokens.append("win")
platform_tokens.append(
"nativeunicode"
if config.db.dialect.convert_unicode
else "dbapiunicode"
)
_has_cext = config.requirements._has_cextensions()
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
return "_".join(platform_tokens)
@ -172,25 +123,32 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
counts = per_platform['counts']
counts[-1] = callcount
current_count = per_platform['current_count']
if current_count < len(counts):
counts[current_count - 1] = callcount
else:
counts[-1] = callcount
if self.write:
self._write()
def _header(self):
return \
"# %s\n"\
"# This file is written out on a per-environment basis.\n"\
"# For each test in aaa_profiling, the corresponding function and \n"\
"# environment is located within this file. If it doesn't exist,\n"\
"# the test is skipped.\n"\
"# If a callcount does exist, it is compared to what we received. \n"\
"# assertions are raised if the counts do not match.\n"\
"# \n"\
"# To add a new callcount test, apply the function_call_count \n"\
"# decorator and re-run the tests using the --write-profiles \n"\
"# option - this file will be rewritten including the new count.\n"\
"# \n"\
"" % (self.fname)
return (
"# %s\n"
"# This file is written out on a per-environment basis.\n"
"# For each test in aaa_profiling, the corresponding "
"function and \n"
"# environment is located within this file. "
"If it doesn't exist,\n"
"# the test is skipped.\n"
"# If a callcount does exist, it is compared "
"to what we received. \n"
"# assertions are raised if the counts do not match.\n"
"# \n"
"# To add a new callcount test, apply the function_call_count \n"
"# decorator and re-run the tests using the --write-profiles \n"
"# option - this file will be rewritten including the new count.\n"
"# \n"
) % (self.fname)
def _read(self):
try:
@ -239,72 +197,69 @@ def function_call_count(variance=0.05):
def decorate(fn):
def wrap(*args, **kw):
if cProfile is None:
raise SkipTest("cProfile is not installed")
if not _profile_stats.has_stats() and not _profile_stats.write:
# run the function anyway, to support dependent tests
# (not a great idea but we have these in test_zoomark)
fn(*args, **kw)
raise SkipTest("No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
"this platform." % _profile_stats.short_fname)
gc_collect()
timespent, load_stats, fn_result = _profile(
fn, *args, **kw
)
stats = load_stats()
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
if expected is None:
expected_count = None
else:
line_no, expected_count = expected
print(("Pstats calls: %d Expected %s" % (
callcount,
expected_count
)
))
stats.print_stats()
# stats.print_callers()
if expected_count:
deviance = int(callcount * variance)
failed = abs(callcount - expected_count) > deviance
if failed:
if _profile_stats.write:
_profile_stats.replace(callcount)
else:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
"of expected %s. Rerun with --write-profiles to "
"regenerate this callcount."
% (
callcount, (variance * 100),
expected_count))
return fn_result
with count_functions(variance=variance):
return fn(*args, **kw)
return update_wrapper(wrap, fn)
return decorate
def _profile(fn, *args, **kw):
filename = "%s.prof" % fn.__name__
@contextlib.contextmanager
def count_functions(variance=0.05):
if cProfile is None:
raise SkipTest("cProfile is not installed")
def load_stats():
st = pstats.Stats(filename)
os.unlink(filename)
return st
if not _profile_stats.has_stats() and not _profile_stats.write:
config.skip_test(
"No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
"this platform." % _profile_stats.short_fname)
gc_collect()
pr = cProfile.Profile()
pr.enable()
#began = time.time()
yield
#ended = time.time()
pr.disable()
#s = compat.StringIO()
stats = pstats.Stats(pr, stream=sys.stdout)
#timespent = ended - began
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
if expected is None:
expected_count = None
else:
line_no, expected_count = expected
print(("Pstats calls: %d Expected %s" % (
callcount,
expected_count
)
))
stats.sort_stats("cumulative")
stats.print_stats()
if expected_count:
deviance = int(callcount * variance)
failed = abs(callcount - expected_count) > deviance
if failed or _profile_stats.force_write:
if _profile_stats.write:
_profile_stats.replace(callcount)
else:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
"of expected %s, platform %s. Rerun with "
"--write-profiles to "
"regenerate this callcount."
% (
callcount, (variance * 100),
expected_count, _profile_stats.platform_key))
began = time.time()
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
filename=filename)
ended = time.time()
return ended - began, load_stats, locals()['result']

View file

@ -0,0 +1,317 @@
from sqlalchemy.engine import url as sa_url
from sqlalchemy import text
from sqlalchemy import exc
from sqlalchemy.util import compat
from . import config, engines
import time
import logging
import os
log = logging.getLogger(__name__)
FOLLOWER_IDENT = None
class register(object):
    """A per-hook registry of backend-specific implementations.

    Functions are registered against a database backend name via the
    :meth:`.for_db` decorator; calling the registry instance looks up
    the backend of the given config / URL and dispatches to the matching
    function, falling back to the ``"*"`` default installed by
    :meth:`.init`.
    """

    def __init__(self):
        # maps backend name (or "*" for the default) -> callable
        self.fns = {}

    @classmethod
    def init(cls, fn):
        """Create a new registry seeded with *fn* as the default hook."""
        return register().for_db("*")(fn)

    def for_db(self, dbname):
        """Return a decorator that registers a function for *dbname*."""
        def decorate(fn):
            self.fns[dbname] = fn
            return self
        return decorate

    def __call__(self, cfg, *arg):
        """Dispatch to the hook registered for the backend of *cfg*.

        *cfg* may be a URL string, a URL object, or a test config
        object carrying an engine at ``cfg.db``.
        """
        if isinstance(cfg, compat.string_types):
            target_url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            target_url = cfg
        else:
            target_url = cfg.db.url

        backend = target_url.get_backend_name()
        if backend in self.fns:
            hook = self.fns[backend]
        else:
            hook = self.fns['*']
        return hook(cfg, *arg)
def create_follower_db(follower_ident):
    """Create the follower database for *follower_ident* on every
    distinct database host in play."""
    for config_obj in _configs_for_db_operation():
        _create_db(config_obj, config_obj.db, follower_ident)
def configure_follower(follower_ident):
    """Apply follower-specific settings to every registered config."""
    for config_obj in config.Config.all_configs():
        _configure_follower(config_obj, follower_ident)
def setup_config(db_url, options, file_config, follower_ident):
    """Build, connectivity-check, and register a test Config for *db_url*.

    If *follower_ident* is given, the URL is first rewritten to point at
    that follower's database and the resulting config is configured for
    the follower.  Returns the registered config object.
    """
    if follower_ident:
        db_url = _follower_url_from_main(db_url, follower_ident)

    db_opts = {}
    _update_db_opts(db_url, db_opts)
    engine = engines.testing_engine(db_url, db_opts)
    _post_configure_engine(db_url, engine, follower_ident)

    # verify connectivity up front so misconfiguration fails early
    engine.connect().close()

    new_config = config.Config.register(engine, db_opts, options, file_config)
    if follower_ident:
        _configure_follower(new_config, follower_ident)
    return new_config
def drop_follower_db(follower_ident):
    """Drop the follower database for *follower_ident* on every
    distinct database host in play."""
    for config_obj in _configs_for_db_operation():
        _drop_db(config_obj, config_obj.db, follower_ident)
def _configs_for_db_operation():
    """Yield one config per distinct (backend, user, host, database)
    combination.

    All engines are disposed before and after iteration so no pooled
    connections are held open while databases are created or dropped.
    """
    for candidate in config.Config.all_configs():
        candidate.db.dispose()

    seen_hosts = set()
    for candidate in config.Config.all_configs():
        target_url = candidate.db.url
        host_key = (
            target_url.get_backend_name(),
            target_url.username,
            target_url.host,
            target_url.database,
        )
        if host_key not in seen_hosts:
            seen_hosts.add(host_key)
            yield candidate

    for candidate in config.Config.all_configs():
        candidate.db.dispose()
@register.init
def _create_db(cfg, eng, ident):
    """Default hook: no backend-specific DB creation routine registered."""
    raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
@register.init
def _drop_db(cfg, eng, ident):
    """Default hook: no backend-specific DB drop routine registered."""
    raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
@register.init
def _update_db_opts(db_url, db_opts):
    """Default hook: no engine-option adjustments for this backend."""
    pass
@register.init
def _configure_follower(cfg, ident):
    """Default hook: no follower-specific config adjustments."""
    pass
@register.init
def _post_configure_engine(url, engine, follower_ident):
    """Default hook: no post-creation engine setup for this backend."""
    pass
@register.init
def _follower_url_from_main(url, ident):
    """Default hook: point the main URL at the follower's database name.

    NOTE(review): mutates the URL object returned by make_url in place;
    assumes URL is mutable in this SQLAlchemy version -- confirm.
    """
    url = sa_url.make_url(url)
    url.database = ident
    return url
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
    """MSSQL: disable legacy schema aliasing in the engine options."""
    db_opts['legacy_schema_aliasing'] = False
@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
    """SQLite: derive the follower URL from the main URL.

    Memory-based (or database-less) URLs are returned unchanged;
    file-based URLs get a per-follower database file.
    """
    parsed = sa_url.make_url(url)
    if parsed.database and parsed.database != ':memory:':
        return sa_url.make_url("sqlite:///%s.db" % ident)
    return parsed
@_post_configure_engine.for_db("sqlite")
def _sqlite_post_configure_engine(url, engine, follower_ident):
    """SQLite: attach a file-based ``test_schema`` database on each new
    connection, using a per-follower file when a follower ident is set.
    """
    from sqlalchemy import event

    @event.listens_for(engine, "connect")
    def connect(dbapi_connection, connection_record):
        # use file DBs in all cases, memory acts kind of strangely
        # as an attached
        if not follower_ident:
            dbapi_connection.execute(
                'ATTACH DATABASE "test_schema.db" AS test_schema')
        else:
            # each follower gets its own schema file
            dbapi_connection.execute(
                'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
                % follower_ident)
@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
    """PostgreSQL: create the follower database by cloning the current
    database as a template, retrying while other users hold connections.
    """
    with eng.connect().execution_options(
            isolation_level="AUTOCOMMIT") as conn:
        # best effort: drop any leftover database from a prior run
        try:
            _pg_drop_db(cfg, conn, ident)
        except Exception:
            pass
        currentdb = conn.scalar("select current_database()")
        # CREATE DATABASE ... TEMPLATE fails while the template database
        # has other connections; retry a few times with a short sleep
        for attempt in range(3):
            try:
                conn.execute(
                    "CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
            except exc.OperationalError as err:
                if attempt != 2 and "accessed by other users" in str(err):
                    time.sleep(.2)
                    continue
                else:
                    raise
            else:
                break
@_create_db.for_db("mysql")
def _mysql_create_db(cfg, eng, ident):
    """MySQL: create the follower database plus its two test schemas."""
    with eng.connect() as conn:
        # best-effort cleanup of leftovers from a previous run
        try:
            _mysql_drop_db(cfg, conn, ident)
        except Exception:
            pass
        for dbname in (
                ident,
                "%s_test_schema" % ident,
                "%s_test_schema_2" % ident):
            conn.execute("CREATE DATABASE %s" % dbname)
@_configure_follower.for_db("mysql")
def _mysql_configure_follower(config, ident):
    """MySQL: point the follower's test schemas at its own databases."""
    config.test_schema = "%s_test_schema" % ident
    config.test_schema_2 = "%s_test_schema_2" % ident
@_create_db.for_db("sqlite")
def _sqlite_create_db(cfg, eng, ident):
    """SQLite: nothing to create; database files appear on first use."""
    pass
@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
    """PostgreSQL: drop the follower database, first terminating any
    remaining backends connected to it (other than this one)."""
    with eng.connect().execution_options(
            isolation_level="AUTOCOMMIT") as conn:
        conn.execute(
            text(
                "select pg_terminate_backend(pid) from pg_stat_activity "
                "where usename=current_user and pid != pg_backend_pid() "
                "and datname=:dname"
            ), dname=ident)
        conn.execute("DROP DATABASE %s" % ident)
@_drop_db.for_db("sqlite")
def _sqlite_drop_db(cfg, eng, ident):
    """SQLite: remove the database file for this ident.

    NOTE(review): for a follower ident only the ``_test_schema`` file is
    removed (not ``%s.db``), and os.remove raises OSError if the file is
    missing -- confirm both are intended.
    """
    if ident:
        os.remove("%s_test_schema.db" % ident)
    else:
        os.remove("%s.db" % ident)
@_drop_db.for_db("mysql")
def _mysql_drop_db(cfg, eng, ident):
    """MySQL: drop the follower database and its test schemas,
    ignoring individual failures (e.g. database does not exist)."""
    with eng.connect() as conn:
        for dbname in (
                "%s_test_schema" % ident,
                "%s_test_schema_2" % ident,
                ident):
            try:
                conn.execute("DROP DATABASE %s" % dbname)
            except Exception:
                pass
@_create_db.for_db("oracle")
def _oracle_create_db(cfg, eng, ident):
    """Oracle: create the follower user plus two tablespace users."""
    # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
    # similar, so that the default tablespace is not "system"; reflection will
    # fail otherwise
    with eng.connect() as conn:
        conn.execute("create user %s identified by xe" % ident)
        conn.execute("create user %s_ts1 identified by xe" % ident)
        conn.execute("create user %s_ts2 identified by xe" % ident)
        conn.execute("grant dba to %s" % (ident, ))
        conn.execute("grant unlimited tablespace to %s" % ident)
        conn.execute("grant unlimited tablespace to %s_ts1" % ident)
        conn.execute("grant unlimited tablespace to %s_ts2" % ident)
@_configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
    """Oracle: point the follower's test schemas at its tablespace users."""
    config.test_schema = "%s_ts1" % ident
    config.test_schema_2 = "%s_ts2" % ident
def _ora_drop_ignore(conn, dbname):
    """Drop the given Oracle user/schema, ignoring database errors.

    :param conn: a connection with an ``execute`` method.
    :param dbname: the user/schema name to drop.
    :return: ``True`` if the drop succeeded, ``False`` if the database
     reported an error (e.g. user does not exist or has active sessions).
    """
    try:
        conn.execute("drop user %s cascade" % dbname)
        # lazy %-style logging args: interpolation is skipped when the
        # level is disabled
        log.info("Reaped db: %s", dbname)
        return True
    except exc.DatabaseError as err:
        # Logger.warn is a deprecated alias for Logger.warning
        log.warning("couldn't drop db: %s", err)
        return False
@_drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
    """Oracle: drop the follower user and its tablespace users,
    best-effort (failures are logged and ignored)."""
    with eng.connect() as conn:
        # cx_Oracle seems to occasionally leak open connections when a large
        # suite it run, even if we confirm we have zero references to
        # connection objects.
        # while there is a "kill session" command in Oracle,
        # it unfortunately does not release the connection sufficiently.
        _ora_drop_ignore(conn, ident)
        _ora_drop_ignore(conn, "%s_ts1" % ident)
        _ora_drop_ignore(conn, "%s_ts2" % ident)
def reap_oracle_dbs(eng):
    """Drop stale ``TEST_%`` Oracle users that have no active session.

    Tablespace users (``*_ts1`` / ``*_ts2``) are dropped together with
    their base user rather than independently.
    """
    log.info("Reaping Oracle dbs...")
    with eng.connect() as conn:
        # users matching the test naming scheme with no live session
        to_reap = conn.execute(
            "select u.username from all_users u where username "
            "like 'TEST_%' and not exists (select username "
            "from v$session where username=u.username)")
        all_names = set([username.lower() for (username, ) in to_reap])
        to_drop = set()
        for name in all_names:
            if name.endswith("_ts1") or name.endswith("_ts2"):
                # tablespace users are added via their base user below
                continue
            else:
                to_drop.add(name)
                if "%s_ts1" % name in all_names:
                    to_drop.add("%s_ts1" % name)
                if "%s_ts2" % name in all_names:
                    to_drop.add("%s_ts2" % name)

        dropped = total = 0
        for total, username in enumerate(to_drop, 1):
            if _ora_drop_ignore(conn, username):
                dropped += 1
        log.info(
            "Dropped %d out of %d stale databases detected", dropped, total)
@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
    """Oracle: the follower connects as its own user with password 'xe'."""
    url = sa_url.make_url(url)
    # NOTE(review): mutates the URL object in place; assumes URL is
    # mutable in this SQLAlchemy version -- confirm
    url.username = ident
    url.password = 'xe'
    return url

View file

@ -0,0 +1,172 @@
from . import fixtures
from . import profiling
from .. import util
import types
from collections import deque
import contextlib
from . import config
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
class ReplayFixtureTest(fixtures.TestBase):
    """Profiling fixture that runs a scenario twice: once against a real
    DBAPI connection while recording every call, then again against a
    playback of that recording while counting function calls.

    Subclasses implement :meth:`._run_steps` and may override
    :meth:`.setup_engine` / :meth:`.teardown_engine`.
    """

    @contextlib.contextmanager
    def _dummy_ctx(self, *arg, **kw):
        # no-op stand-in for profiling.count_functions during the
        # recording pass
        yield

    def test_invocation(self):
        dbapi_session = ReplayableSession()
        creator = config.db.pool._creator
        recorder = lambda: dbapi_session.recorder(creator())
        engine = create_engine(
            config.db.url, creator=recorder,
            use_native_hstore=False)
        self.metadata = MetaData(engine)
        self.engine = engine
        self.session = Session(engine)

        self.setup_engine()
        try:
            # first pass: record the DBAPI conversation
            self._run_steps(ctx=self._dummy_ctx)
        finally:
            self.teardown_engine()
            engine.dispose()

        player = lambda: dbapi_session.player()
        engine = create_engine(
            config.db.url, creator=player,
            use_native_hstore=False)
        self.metadata = MetaData(engine)
        self.engine = engine
        self.session = Session(engine)

        self.setup_engine()
        try:
            # second pass: replay the recording while profiling
            self._run_steps(ctx=profiling.count_functions)
        finally:
            self.session.close()
            engine.dispose()

    def setup_engine(self):
        # hook for subclasses; no-op by default
        pass

    def teardown_engine(self):
        # hook for subclasses; no-op by default
        pass

    def _run_steps(self, ctx):
        # subclasses perform their scenario here, wrapping the profiled
        # portion in ``ctx()``
        raise NotImplementedError()
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.
    """

    # buffer sentinels: ``Callable`` marks a result that was wrapped in
    # a proxy; ``NoAttribute`` marks an attribute access that raised
    Callable = object()
    NoAttribute = object()

    # "Natives" is the set of types whose instances are recorded by
    # value; results of any other type (functions, methods, DBAPI
    # objects...) are wrapped in a Recorder/Player proxy instead, so
    # their own calls can be recorded too.
    if util.py2k:
        Natives = set([getattr(types, t)
                       for t in dir(types) if not t.startswith('_')]).\
            difference([getattr(types, t)
                        for t in ('FunctionType', 'BuiltinFunctionType',
                                  'MethodType', 'BuiltinMethodType',
                                  'LambdaType', 'UnboundMethodType',)])
    else:
        Natives = set([getattr(types, t)
                       for t in dir(types) if not t.startswith('_')]).\
            union([type(t) if not isinstance(t, type)
                   else t for t in __builtins__.values()]).\
            difference([getattr(types, t)
                        for t in ('FunctionType', 'BuiltinFunctionType',
                                  'MethodType', 'BuiltinMethodType',
                                  'LambdaType', )])

    def __init__(self):
        # shared FIFO of recorded results, consumed in order on playback
        self.buffer = deque()

    def recorder(self, base):
        """Return a recording proxy around *base*."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a playback proxy fed from the recorded buffer."""
        return self.Player(self.buffer)

    class Recorder(object):
        """Proxy that forwards to a real subject, appending each result
        (or a sentinel) to the shared buffer."""

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                # non-native result: record the Callable sentinel and
                # hand back a new proxy wrapping the result
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        @property
        def _sqla_unwrap(self):
            # SQLAlchemy testing hook: expose the real wrapped object
            return self._subject

        def __getattribute__(self, key):
            try:
                # own attributes (_buffer, _subject, _sqla_unwrap) first
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                # record the miss so playback raises the same error
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        """Proxy that replays results from the buffer in order without
        touching any real DBAPI."""

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                # the recorded result was itself a proxy; keep self
                return self
            else:
                return result

        @property
        def _sqla_unwrap(self):
            # no real DBAPI object exists during playback
            return None

        def __getattribute__(self, key):
            try:
                # own attributes (_buffer, _sqla_unwrap) first
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                # the recording raised AttributeError here; replay it
                raise AttributeError(key)
            else:
                return result

View file

@ -1,5 +1,5 @@
# testing/requirements.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -16,6 +16,7 @@ to provide specific inclusion/exclusions.
"""
from . import exclusions
from .. import util
class Requirements(object):
@ -101,6 +102,14 @@ class SuiteRequirements(Requirements):
return exclusions.open()
@property
def bound_limit_offset(self):
"""target database can render LIMIT and/or OFFSET using a bound
parameter
"""
return exclusions.open()
@property
def boolean_col_expressions(self):
"""Target database must support boolean expressions as columns"""
@ -179,7 +188,7 @@ class SuiteRequirements(Requirements):
return exclusions.only_if(
lambda config: config.db.dialect.implicit_returning,
"'returning' not supported by database"
"%(database)s %(does_support)s 'returning'"
)
@property
@ -304,6 +313,25 @@ class SuiteRequirements(Requirements):
def foreign_key_constraint_reflection(self):
return exclusions.open()
@property
def temp_table_reflection(self):
return exclusions.open()
@property
def temp_table_names(self):
"""target dialect supports listing of temporary table names"""
return exclusions.closed()
@property
def temporary_tables(self):
"""target database supports temporary tables"""
return exclusions.open()
@property
def temporary_views(self):
"""target database supports temporary views"""
return exclusions.closed()
@property
def index_reflection(self):
return exclusions.open()
@ -313,6 +341,14 @@ class SuiteRequirements(Requirements):
"""target dialect supports reflection of unique constraints"""
return exclusions.open()
@property
def duplicate_key_raises_integrity_error(self):
"""target dialect raises IntegrityError when reporting an INSERT
with a primary key violation. (hint: it should)
"""
return exclusions.open()
@property
def unbounded_varchar(self):
"""Target database must support VARCHAR with no length"""
@ -584,6 +620,14 @@ class SuiteRequirements(Requirements):
"""
return exclusions.open()
@property
def graceful_disconnects(self):
"""Target driver must raise a DBAPI-level exception, such as
InterfaceError, when the underlying connection has been closed
and the execute() method is called.
"""
return exclusions.open()
@property
def skip_mysql_on_windows(self):
"""Catchall for a large variety of MySQL on Windows failures"""
@ -601,6 +645,38 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: config.options.low_connections)
@property
def timing_intensive(self):
return exclusions.requires_tag("timing_intensive")
@property
def memory_intensive(self):
return exclusions.requires_tag("memory_intensive")
@property
def threading_with_mock(self):
"""Mark tests that use threading and mock at the same time - stability
issues have been observed with coverage + python 3.3
"""
return exclusions.skip_if(
lambda config: util.py3k and config.options.has_coverage,
"Stability issues with coverage + py3k"
)
@property
def no_coverage(self):
"""Test should be skipped if coverage is enabled.
This is to block tests that exercise libraries that seem to be
sensitive to coverage, such as Postgresql notice logging.
"""
return exclusions.skip_if(
lambda config: config.options.has_coverage,
"Issues observed when coverage is enabled"
)
def _has_mysql_on_windows(self, config):
return False

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python
# testing/runner.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -30,7 +30,7 @@ SQLAlchemy itself is possible.
"""
from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy
from .plugin.noseplugin import NoseSQLAlchemy
import nose

View file

@ -1,5 +1,5 @@
# testing/schema.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -67,7 +67,7 @@ def Column(*args, **kw):
test_opts = dict([(k, kw.pop(k)) for k in list(kw)
if k.startswith('test_')])
if config.requirements.foreign_key_ddl.predicate(config):
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
col = schema.Column(*args, **kw)

View file

@ -1,4 +1,5 @@
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
from sqlalchemy.testing.suite.test_insert import *
from sqlalchemy.testing.suite.test_sequence import *

View file

@ -0,0 +1,41 @@
from .. import fixtures, config
from ..config import requirements
from sqlalchemy import exc
from sqlalchemy import Integer, String
from .. import assert_raises
from ..schema import Table, Column
class ExceptionTest(fixtures.TablesTest):
    """Test basic exception wrapping.

    DBAPIs vary a lot in exception behavior so to actually anticipate
    specific exceptions from real round trips, we need to be conservative.
    """
    # wipe rows between tests so each test controls its own data
    run_deletes = 'each'

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # manual (non-autoincrement) pk so a duplicate can be forced
        Table('manual_pk', metadata,
              Column('id', Integer, primary_key=True, autoincrement=False),
              Column('data', String(50))
              )

    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):
        with config.db.begin() as conn:
            conn.execute(
                self.tables.manual_pk.insert(),
                {'id': 1, 'data': 'd1'}
            )

            # inserting the same pk again must raise IntegrityError
            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {'id': 1, 'data': 'd1'}
            )

View file

@ -4,7 +4,7 @@ from .. import exclusions
from ..assertions import eq_
from .. import engines
from sqlalchemy import Integer, String, select, util
from sqlalchemy import Integer, String, select, literal_column, literal
from ..schema import Table, Column
@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest):
Column('id', Integer, primary_key=True, autoincrement=False),
Column('data', String(50))
)
Table('includes_defaults', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('data', String(50)),
Column('x', Integer, default=5),
Column('y', Integer,
default=literal_column("2", type_=Integer) + literal(2)))
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
@ -102,7 +109,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
self.tables.autoinc_pk.insert(),
data="some data"
)
assert r.closed
assert r._soft_closed
assert not r.closed
assert r.is_insert
assert not r.returns_rows
@ -112,7 +120,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
self.tables.autoinc_pk.insert(),
data="some data"
)
assert r.closed
assert r._soft_closed
assert not r.closed
assert r.is_insert
assert not r.returns_rows
@ -121,7 +130,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
)
assert r.closed
assert r._soft_closed
assert not r.closed
r = config.db.execute(
self.tables.autoinc_pk.select().
@ -158,6 +168,34 @@ class InsertBehaviorTest(fixtures.TablesTest):
("data3", ), ("data3", )]
)
@requirements.insert_from_select
def test_insert_from_select_with_defaults(self):
table = self.tables.includes_defaults
config.db.execute(
table.insert(),
[
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
]
)
config.db.execute(
table.insert(inline=True).
from_select(("id", "data",),
select([table.c.id + 5, table.c.data]).
where(table.c.data.in_(["data2", "data3"]))
),
)
eq_(
config.db.execute(
select([table]).order_by(table.c.data, table.c.id)
).fetchall(),
[(1, 'data1', 5, 4), (2, 'data2', 5, 4),
(7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
)
class ReturningTest(fixtures.TablesTest):
run_create_tables = 'each'

View file

@ -39,11 +39,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
__backend__ = True
@classmethod
def setup_bind(cls):
if config.requirements.independent_connections.enabled:
from sqlalchemy import pool
return engines.testing_engine(
options=dict(poolclass=pool.StaticPool))
else:
return config.db
@classmethod
def define_tables(cls, metadata):
cls.define_reflected_tables(metadata, None)
if testing.requires.schemas.enabled:
cls.define_reflected_tables(metadata, "test_schema")
cls.define_reflected_tables(metadata, testing.config.test_schema)
@classmethod
def define_reflected_tables(cls, metadata, schema):
@ -95,6 +104,43 @@ class ComponentReflectionTest(fixtures.TablesTest):
cls.define_index(metadata, users)
if testing.requires.view_column_reflection.enabled:
cls.define_views(metadata, schema)
if not schema and testing.requires.temp_table_reflection.enabled:
cls.define_temp_tables(metadata)
@classmethod
def define_temp_tables(cls, metadata):
# cheat a bit, we should fix this with some dialect-level
# temp table fixture
if testing.against("oracle"):
kw = {
'prefixes': ["GLOBAL TEMPORARY"],
'oracle_on_commit': 'PRESERVE ROWS'
}
else:
kw = {
'prefixes': ["TEMPORARY"],
}
user_tmp = Table(
"user_tmp", metadata,
Column("id", sa.INT, primary_key=True),
Column('name', sa.VARCHAR(50)),
Column('foo', sa.INT),
sa.UniqueConstraint('name', name='user_tmp_uq'),
sa.Index("user_tmp_ix", "foo"),
**kw
)
if testing.requires.view_reflection.enabled and \
testing.requires.temporary_views.enabled:
event.listen(
user_tmp, "after_create",
DDL("create temporary view user_tmp_v as "
"select * from user_tmp")
)
event.listen(
user_tmp, "before_drop",
DDL("drop view user_tmp_v")
)
@classmethod
def define_index(cls, metadata, users):
@ -126,7 +172,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_schema_names(self):
insp = inspect(testing.db)
self.assert_('test_schema' in insp.get_schema_names())
self.assert_(testing.config.test_schema in insp.get_schema_names())
@testing.requires.schema_reflection
def test_dialect_initialize(self):
@ -147,6 +193,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
insp = inspect(meta.bind)
if table_type == 'view':
table_names = insp.get_view_names(schema)
table_names.sort()
@ -162,6 +209,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
answer = ['dingalings', 'email_addresses', 'users']
eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
def test_get_temp_table_names(self):
    """The inspector reports the fixture temp table among temp names."""
    inspector = inspect(self.bind)
    reflected_names = inspector.get_temp_table_names()
    eq_(sorted(reflected_names), ['user_tmp'])
@testing.requires.view_reflection
@testing.requires.temp_table_names
@testing.requires.temporary_views
def test_get_temp_view_names(self):
    """The inspector reports the fixture temp view among temp view names."""
    inspector = inspect(self.bind)
    reflected_names = inspector.get_temp_view_names()
    eq_(sorted(reflected_names), ['user_tmp_v'])
@testing.requires.table_reflection
def test_get_table_names(self):
    # exercise the shared helper against the default (no schema) path
    self._test_get_table_names()
@ -174,7 +235,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.requires.schemas
def test_get_table_names_with_schema(self):
self._test_get_table_names('test_schema')
self._test_get_table_names(testing.config.test_schema)
@testing.requires.view_column_reflection
def test_get_view_names(self):
@ -183,7 +244,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_names_with_schema(self):
self._test_get_table_names('test_schema', table_type='view')
self._test_get_table_names(
testing.config.test_schema, table_type='view')
@testing.requires.table_reflection
@testing.requires.view_column_reflection
@ -291,7 +353,29 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.requires.schemas
def test_get_columns_with_schema(self):
self._test_get_columns(schema='test_schema')
self._test_get_columns(schema=testing.config.test_schema)
@testing.requires.temp_table_reflection
def test_get_temp_table_columns(self):
    """Reflected temp-table columns match the fixture, in order."""
    meta = MetaData(self.bind)
    user_tmp = self.tables.user_tmp
    inspector = inspect(meta.bind)
    reflected = inspector.get_columns('user_tmp')
    self.assert_(len(reflected) > 0, len(reflected))
    for position, expected_col in enumerate(user_tmp.columns):
        eq_(expected_col.name, reflected[position]['name'])
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
def test_get_temp_view_columns(self):
    """The temp view exposes the source table's columns, in order."""
    inspector = inspect(self.bind)
    reflected_names = [
        col['name'] for col in inspector.get_columns('user_tmp_v')
    ]
    eq_(reflected_names, ['id', 'name', 'foo'])
@testing.requires.view_column_reflection
def test_get_view_columns(self):
@ -300,7 +384,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_columns_with_schema(self):
self._test_get_columns(schema='test_schema', table_type='view')
self._test_get_columns(
schema=testing.config.test_schema, table_type='view')
@testing.provide_metadata
def _test_get_pk_constraint(self, schema=None):
@ -327,7 +412,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.primary_key_constraint_reflection
@testing.requires.schemas
def test_get_pk_constraint_with_schema(self):
self._test_get_pk_constraint(schema='test_schema')
self._test_get_pk_constraint(schema=testing.config.test_schema)
@testing.requires.table_reflection
@testing.provide_metadata
@ -385,7 +470,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.foreign_key_constraint_reflection
@testing.requires.schemas
def test_get_foreign_keys_with_schema(self):
self._test_get_foreign_keys(schema='test_schema')
self._test_get_foreign_keys(schema=testing.config.test_schema)
@testing.provide_metadata
def _test_get_indexes(self, schema=None):
@ -418,25 +503,57 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.index_reflection
@testing.requires.schemas
def test_get_indexes_with_schema(self):
self._test_get_indexes(schema='test_schema')
self._test_get_indexes(schema=testing.config.test_schema)
@testing.requires.unique_constraint_reflection
def test_get_unique_constraints(self):
self._test_get_unique_constraints()
@testing.requires.temp_table_reflection
@testing.requires.unique_constraint_reflection
def test_get_temp_table_unique_constraints(self):
    """The temp table's unique constraint reflects with name and columns."""
    inspector = inspect(self.bind)
    reflected = inspector.get_unique_constraints('user_tmp')
    for constraint in reflected:
        # Different dialects handle duplicate index and constraints
        # differently, so ignore this flag
        constraint.pop('duplicates_index', None)
    eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
@testing.requires.temp_table_reflection
def test_get_temp_table_indexes(self):
    """The explicitly-declared index on the temp table is reflected."""
    inspector = inspect(self.bind)
    reflected = inspector.get_indexes('user_tmp')
    for index in reflected:
        index.pop('dialect_options', None)
    # TODO: we need to add better filtering for indexes/uq constraints
    # that are doubled up
    matching = [idx for idx in reflected if idx['name'] == 'user_tmp_ix']
    eq_(
        matching,
        [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
    )
@testing.requires.unique_constraint_reflection
@testing.requires.schemas
def test_get_unique_constraints_with_schema(self):
self._test_get_unique_constraints(schema='test_schema')
self._test_get_unique_constraints(schema=testing.config.test_schema)
@testing.provide_metadata
def _test_get_unique_constraints(self, schema=None):
# SQLite dialect needs to parse the names of the constraints
# separately from what it gets from PRAGMA index_list(), and
# then matches them up. so same set of column_names in two
# constraints will confuse it. Perhaps we should no longer
# bother with index_list() here since we have the whole
# CREATE TABLE?
uniques = sorted(
[
{'name': 'unique_a', 'column_names': ['a']},
{'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
{'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
{'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
{'name': 'i.have.dots', 'column_names': ['b']},
{'name': 'i have spaces', 'column_names': ['c']},
],
key=operator.itemgetter('name')
)
@ -464,6 +581,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
for orig, refl in zip(uniques, reflected):
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
refl.pop('duplicates_index', None)
eq_(orig, refl)
@testing.provide_metadata
@ -486,7 +606,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_reflection
@testing.requires.schemas
def test_get_view_definition_with_schema(self):
self._test_get_view_definition(schema='test_schema')
self._test_get_view_definition(schema=testing.config.test_schema)
@testing.only_on("postgresql", "PG specific feature")
@testing.provide_metadata
@ -503,7 +623,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schemas
def test_get_table_oid_with_schema(self):
self._test_get_table_oid('users', schema='test_schema')
self._test_get_table_oid('users', schema=testing.config.test_schema)
@testing.requires.table_reflection
@testing.provide_metadata

View file

@ -2,13 +2,13 @@ from .. import fixtures, config
from ..assertions import eq_
from sqlalchemy import util
from sqlalchemy import Integer, String, select, func
from sqlalchemy import Integer, String, select, func, bindparam
from sqlalchemy import testing
from ..schema import Table, Column
class OrderByLabelTest(fixtures.TablesTest):
"""Test the dialect sends appropriate ORDER BY expressions when
labels are used.
@ -85,3 +85,108 @@ class OrderByLabelTest(fixtures.TablesTest):
select([lx]).order_by(lx.desc()),
[(7, ), (5, ), (3, )]
)
def test_group_by_composed(self):
    """GROUP BY / ORDER BY on a labeled composed expression round-trips."""
    table = self.tables.some_table
    summed = (table.c.x + table.c.y).label('lx')
    stmt = (
        select([func.count(table.c.id), summed])
        .group_by(summed)
        .order_by(summed)
    )
    self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
class LimitOffsetTest(fixtures.TablesTest):
    """Exercise LIMIT/OFFSET rendering with literal and bound values."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table", metadata,
            Column('id', Integer, primary_key=True),
            Column('x', Integer),
            Column('y', Integer))

    @classmethod
    def insert_data(cls):
        rows = [
            {"id": 1, "x": 1, "y": 2},
            {"id": 2, "x": 2, "y": 3},
            {"id": 3, "x": 3, "y": 4},
            {"id": 4, "x": 4, "y": 5},
        ]
        config.db.execute(cls.tables.some_table.insert(), rows)

    def _assert_result(self, select, result, params=()):
        # execute with optional bound parameters; compare the full rowset
        eq_(config.db.execute(select, params).fetchall(), result)

    def test_simple_limit(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).limit(2)
        self._assert_result(stmt, [(1, 1, 2), (2, 2, 3)])

    @testing.requires.offset
    def test_simple_offset(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).offset(2)
        self._assert_result(stmt, [(3, 3, 4), (4, 4, 5)])

    @testing.requires.offset
    def test_simple_limit_offset(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
        self._assert_result(stmt, [(2, 2, 3), (3, 3, 4)])

    @testing.requires.offset
    def test_limit_offset_nobinds(self):
        """test that 'literal binds' mode works - no bound params."""

        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
        compiled = stmt.compile(
            dialect=config.db.dialect,
            compile_kwargs={"literal_binds": True})
        # execute the raw SQL string so no bound parameters are involved
        self._assert_result(str(compiled), [(2, 2, 3), (3, 3, 4)])

    @testing.requires.bound_limit_offset
    def test_bound_limit(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).limit(bindparam('l'))
        self._assert_result(
            stmt, [(1, 1, 2), (2, 2, 3)], params={"l": 2})

    @testing.requires.bound_limit_offset
    def test_bound_offset(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).offset(bindparam('o'))
        self._assert_result(
            stmt, [(3, 3, 4), (4, 4, 5)], params={"o": 2})

    @testing.requires.bound_limit_offset
    def test_bound_limit_offset(self):
        table = self.tables.some_table
        stmt = select([table]).order_by(table.c.id).\
            limit(bindparam("l")).offset(bindparam("o"))
        self._assert_result(
            stmt, [(2, 2, 3), (3, 3, 4)], params={"l": 2, "o": 1})

View file

@ -86,11 +86,11 @@ class HasSequenceTest(fixtures.TestBase):
@testing.requires.schemas
def test_has_sequence_schema(self):
s1 = Sequence('user_id_seq', schema="test_schema")
s1 = Sequence('user_id_seq', schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(
testing.db, 'user_id_seq', schema="test_schema"), True)
testing.db, 'user_id_seq', schema=config.test_schema), True)
finally:
testing.db.execute(schema.DropSequence(s1))
@ -101,7 +101,7 @@ class HasSequenceTest(fixtures.TestBase):
@testing.requires.schemas
def test_has_sequence_schemas_neg(self):
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
schema="test_schema"),
schema=config.test_schema),
False)
@testing.requires.schemas
@ -110,14 +110,14 @@ class HasSequenceTest(fixtures.TestBase):
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
schema="test_schema"),
schema=config.test_schema),
False)
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self):
s1 = Sequence('user_id_seq', schema="test_schema")
s1 = Sequence('user_id_seq', schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),

View file

@ -1,5 +1,5 @@
# testing/util.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -147,6 +147,10 @@ def run_as_contextmanager(ctx, fn, *arg, **kw):
simulating the behavior of 'with' to support older
Python versions.
This is not necessary anymore as we have placed 2.6
as minimum Python version, however some tests are still using
this structure.
"""
obj = ctx.__enter__()
@ -181,6 +185,7 @@ def provide_metadata(fn, *args, **kw):
"""Provide bound MetaData for a single test, dropping afterwards."""
from . import config
from . import engines
from sqlalchemy import schema
metadata = schema.MetaData(config.db)
@ -190,10 +195,29 @@ def provide_metadata(fn, *args, **kw):
try:
return fn(*args, **kw)
finally:
metadata.drop_all()
engines.drop_all_tables(metadata, config.db)
self.metadata = prev_meta
def force_drop_names(*names):
    """Return a decorator that drops the named tables after the test.

    Used to isolate tests whose tables participate in foreign key
    cycles and so may survive the normal teardown.
    """
    from . import config
    from sqlalchemy import inspect

    @decorator
    def run_then_drop(fn, *args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            # always attempt the drop, even when the test fails
            drop_all_tables(
                config.db, inspect(config.db), include_names=names)

    return run_then_drop
class adict(dict):
"""Dict keys available as attributes. Shadows."""
@ -203,5 +227,54 @@ class adict(dict):
except KeyError:
return dict.__getattribute__(self, key)
def get_all(self, *keys):
def __call__(self, *keys):
return tuple([self[key] for key in keys])
get_all = __call__
def drop_all_tables(engine, inspector, schema=None, include_names=None):
    """Drop every table visible to *inspector*, in FK-safe order.

    Tables are dropped in reverse topological order; where a dependency
    cycle prevents ordering, the offending foreign key constraints are
    dropped first (only on dialects that support ALTER).  When
    *include_names* is given, only those table names are touched.
    """
    from sqlalchemy import Column, Table, Integer, MetaData, \
        ForeignKeyConstraint
    from sqlalchemy.schema import DropTable, DropConstraint

    wanted = set(include_names) if include_names is not None else None

    with engine.connect() as conn:
        sorted_entries = inspector.get_sorted_table_and_fkc_names(
            schema=schema)
        for tname, fkcs in reversed(sorted_entries):
            if tname:
                if wanted is not None and tname not in wanted:
                    continue
                conn.execute(DropTable(
                    Table(tname, MetaData(), schema=schema)
                ))
            elif fkcs:
                # cycle-breaking constraints; needs ALTER support
                if not engine.dialect.supports_alter:
                    continue
                for tname, fkc in fkcs:
                    if wanted is not None and tname not in wanted:
                        continue
                    # stub table: only the constraint name matters here
                    stub = Table(
                        tname, MetaData(),
                        Column('x', Integer),
                        Column('y', Integer),
                        schema=schema
                    )
                    conn.execute(DropConstraint(
                        ForeignKeyConstraint(
                            [stub.c.x], [stub.c.y], name=fkc)
                    ))
def teardown_events(event_cls):
    """Decorator factory: run ``event_cls._clear()`` after the test.

    The clear happens even when the wrapped test raises, so state
    registered on *event_cls* does not leak into later tests.
    """
    @decorator
    def wrapped(fn, *arg, **kw):
        try:
            return fn(*arg, **kw)
        finally:
            event_cls._clear()

    return wrapped

View file

@ -1,5 +1,5 @@
# testing/warnings.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@ -9,25 +9,11 @@ from __future__ import absolute_import
import warnings
from .. import exc as sa_exc
from .. import util
import re
from . import assertions
def testing_warn(msg, stacklevel=3):
"""Replaces sqlalchemy.util.warn during tests."""
filename = "sqlalchemy.testing.warnings"
lineno = 1
if isinstance(msg, util.string_types):
warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
else:
warnings.warn_explicit(msg, filename, lineno)
def resetwarnings():
"""Reset warning behavior to testing defaults."""
util.warn = util.langhelpers.warn = testing_warn
def setup_filters():
"""Set global warning behavior for the test suite."""
warnings.filterwarnings('ignore',
category=sa_exc.SAPendingDeprecationWarning)
@ -35,24 +21,14 @@ def resetwarnings():
warnings.filterwarnings('error', category=sa_exc.SAWarning)
def assert_warnings(fn, warnings, regex=False):
"""Assert that each of the given warnings are emitted by fn."""
def assert_warnings(fn, warning_msgs, regex=False):
"""Assert that each of the given warnings are emitted by fn.
from .assertions import eq_, emits_warning
Deprecated. Please use assertions.expect_warnings().
canary = []
orig_warn = util.warn
"""
def capture_warnings(*args, **kw):
orig_warn(*args, **kw)
popwarn = warnings.pop(0)
canary.append(popwarn)
if regex:
assert re.match(popwarn, args[0])
else:
eq_(args[0], popwarn)
util.warn = util.langhelpers.warn = capture_warnings
with assertions._expect_warnings(
sa_exc.SAWarning, warning_msgs, regex=regex):
return fn()
result = emits_warning()(fn)()
assert canary, "No warning was emitted"
return result