# sql/util.py
# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""High level utilities which build upon other modules here.

"""
|
|
|
|
from .. import exc, util
|
|
from .base import _from_objects, ColumnSet
|
|
from . import operators, visitors
|
|
from itertools import chain
|
|
from collections import deque
|
|
|
|
from .elements import BindParameter, ColumnClause, ColumnElement, \
|
|
Null, UnaryExpression, literal_column, Label, _label_reference, \
|
|
_textual_label_reference
|
|
from .selectable import ScalarSelect, Join, FromClause, FromGrouping
|
|
from .schema import Column
|
|
|
|
# Module-level alias for Join._join_condition, produced by
# util.langhelpers.public_factory; the second argument is the dotted
# location string passed to that factory.
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)
|
|
|
|
# names that are still being imported from the outside
|
|
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
|
|
from .elements import _find_columns
|
|
from .ddl import sort_tables
|
|
|
|
|
|
def find_join_source(clauses, join_to):
    """Locate the first FROM clause that can be joined against a selectable.

    The FROM objects of ``join_to`` are gathered once via
    ``_from_objects()``; each entry of ``clauses`` is then tested, in
    order, with ``is_derived_from()`` against those objects.

    Returns a two-tuple ``(index, clause)`` for the first entry of
    ``clauses`` that is derived from any of them, or ``(None, None)``
    if no entry matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == (0, clause1)

    """

    targets = list(_from_objects(join_to))
    for position, candidate in enumerate(clauses):
        if any(candidate.is_derived_from(target) for target in targets):
            return position, candidate
    return None, None
|
|
|
|
|
|
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    The function returns nothing; results are delivered only through
    calls to ``fn``.

    """
    # Comparisons currently being expanded, innermost last.
    stack = []

    def visit(element):
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == 'binary' and \
                operators.is_comparison(element.operator):
            # A comparison consumes its own operands and yields nothing
            # to whatever encloses it.
            # NOTE(review): the right side is only traversed while the
            # left side is yielding; if the left operand yields no
            # elements (e.g. it is itself a comparison), the right
            # operand is never visited.
            stack.append(element)
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[-1], l, r)
            stack.pop()
            # An earlier revision followed this with a loop calling
            # visit() on each child without iterating the result; since
            # visit() is a generator function that loop never executed
            # any traversal and has been dropped.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    # visit() is lazy; exhaust it so that every fn() call takes place.
    list(visit(expr))
|
|
|
|
|
|
def find_tables(clause, check_columns=False,
                include_aliases=False, include_joins=False,
                include_selects=False, include_crud=False):
    """Collect Table objects found while traversing an expression.

    Elements visited as ``'table'`` are always collected.  The flags
    widen what is appended to the returned list:

    * ``include_selects`` - 'select' and 'compound_select' elements
    * ``include_joins`` - 'join' elements
    * ``include_aliases`` - 'alias' elements
    * ``include_crud`` - the ``.table`` of 'insert', 'update' and
      'delete' elements
    * ``check_columns`` - the ``.table`` of each 'column' element

    The traversal is run with ``column_collections`` set to False.
    The result is a plain list and may contain repeated entries.

    """

    found = []
    collect = found.append

    def collect_owner(element):
        # record the table an element belongs to, not the element itself
        found.append(element.table)

    optional_handlers = (
        (include_selects, ('select', 'compound_select'), collect),
        (include_joins, ('join',), collect),
        (include_aliases, ('alias',), collect),
        (include_crud, ('insert', 'update', 'delete'), collect_owner),
        (check_columns, ('column',), collect_owner),
    )

    dispatch = {'table': collect}
    for enabled, visit_names, handler in optional_handlers:
        if enabled:
            for visit_name in visit_names:
                dispatch[visit_name] = handler

    visitors.traverse(clause, {'column_collections': False}, dispatch)
    return found
|
|
|
|
|
|
def unwrap_order_by(clause):
    """Reduce an 'order by' expression to its individual column
    expressions, discarding DESC/ASC/NULLS FIRST/NULLS LAST wrappers.

    Returns a ``util.column_set()`` of the expressions located.

    """

    result = util.column_set()
    pending = deque([clause])
    while pending:
        node = pending.popleft()

        # Anything that is not a column expression, as well as a unary
        # expression carrying an ordering modifier, is only a container:
        # descend into its children instead of collecting it.
        if not isinstance(node, ColumnElement) or (
                isinstance(node, UnaryExpression) and
                operators.is_ordering_modifier(node.modifier)):
            pending.extend(node.get_children())
            continue

        # a label reference stands in for the element it wraps
        if isinstance(node, _label_reference):
            node = node.element

        # textual label references are skipped entirely
        if not isinstance(node, _textual_label_reference):
            result.add(node)

    return result
|
|
|
|
|
|
def clause_is_present(clause, search):
    """Return True if ``clause`` is plainly present within ``search``,
    without any subqueries or aliases involved.

    Only the elements produced by ``surface_selectables(search)`` are
    considered, i.e. the search descends through Joins (and FROM
    groupings) but nothing else.

    """

    # Compare with == rather than identity so that the comparison
    # defined by Annotated objects takes part.
    return any(
        clause == candidate
        for candidate in surface_selectables(search)
    )
|
|
|
|
|
|
def surface_selectables(clause):
    """Generate ``clause`` plus the selectables directly reachable from it.

    A Join contributes its ``left`` and ``right`` sides and a
    FromGrouping contributes its wrapped ``element``; any other kind of
    element is yielded but not descended into.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending += [current.left, current.right]
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
|
|
|
|
|
|
def selectables_overlap(left, right):
    """Return True if ``left`` and ``right`` share at least one surface
    selectable, as produced by ``surface_selectables()``."""

    left_surface = set(surface_selectables(left))
    shared = left_surface.intersection(surface_selectables(right))
    return bool(shared)
|
|
|
|
|
|
def bind_values(clause):
    """Return the "bound" values found in the given clause as a list,
    in the order the traversal visits them.

    The ``effective_value`` of each 'bindparam' element is collected.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """

    collected = []
    visitors.traverse(
        clause, {},
        {'bindparam': lambda bind: collected.append(bind.effective_value)}
    )
    return collected
|
|
|
|
|
|
def _quote_ddl_expr(element):
    """Render a value as a literal for use in DDL text.

    Strings are wrapped in single quotes with embedded single quotes
    doubled; any other value is rendered with ``repr()``.

    """
    if not isinstance(element, util.string_types):
        return repr(element)

    escaped = element.replace("'", "''")
    return "'%s'" % escaped
|
|
|
|
|
|
class _repr_params(object):
|
|
"""A string view of bound parameters, truncating
|
|
display to the given number of 'multi' parameter sets.
|
|
|
|
"""
|
|
|
|
def __init__(self, params, batches):
|
|
self.params = params
|
|
self.batches = batches
|
|
|
|
def __repr__(self):
|
|
if isinstance(self.params, (list, tuple)) and \
|
|
len(self.params) > self.batches and \
|
|
isinstance(self.params[0], (list, dict, tuple)):
|
|
msg = " ... displaying %i of %i total bound parameter sets ... "
|
|
return ' '.join((
|
|
repr(self.params[:self.batches - 2])[0:-1],
|
|
msg % (self.batches, len(self.params)),
|
|
repr(self.params[-2:])[1:]
|
|
))
|
|
else:
|
|
return repr(self.params)
|
|
|
|
|
|
def adapt_criterion_to_null(crit, nulls):
    """Given criterion containing bind params, convert selected
    comparisons to IS NULL.

    A binary expression is rewritten when one of its sides is a
    BindParameter whose ``_identifying_key`` is contained in ``nulls``.
    The left side is checked first.  The result of
    ``visitors.cloned_traverse()`` over ``crit`` is returned.

    """

    def names_null_bind(operand):
        return isinstance(operand, BindParameter) and \
            operand._identifying_key in nulls

    def visit_binary(binary):
        if names_null_bind(binary.left):
            # the NULL always goes on the right, so move the other
            # operand over to the left first
            binary.left = binary.right
        elif not names_null_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
|
|
|
|
|
|
def splice_joins(left, right, stop_on=None):
    """Adapt ``right`` against ``left`` using a ClauseAdapter.

    When ``left`` is None, ``right`` is returned untouched.  Otherwise
    the chain of Join objects along the ``.left`` side of ``right`` is
    walked: each Join (other than ``stop_on``) is cloned and its
    ``onclause`` adapted, and the first element that is not a Join, or
    that is ``stop_on``, is adapted as a whole.  Each adapted element
    is assigned as the ``.left`` of the clone above it, and the
    topmost result is returned.

    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    result = None
    pending = [(right, None)]

    while pending:
        node, parent = pending.pop()

        if not isinstance(node, Join) or node is stop_on:
            node = adapter.traverse(node)
        else:
            node = node._clone()
            node._reset_exported()
            node.onclause = adapter.traverse(node.onclause)
            # continue down the left side, remembering the clone so the
            # adapted child can be attached to it
            pending.append((node.left, node))

        if parent is not None:
            parent.left = node
        if result is None:
            # the first element processed is the top of the new tree
            result = node

    return result
|
|
|
|
|
|
def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case a column is only
    omitted when its name equals that of its equivalent.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    Returns a ColumnSet of the columns that remain.

    """
    ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
    only_synonyms = kw.pop('only_synonyms', False)

    columns = util.ordered_column_set(columns)

    # columns found to be redundant with another column in the set
    omit = util.column_set()
    for col in columns:
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                # fk.column is resolved only once there is another
                # column to compare against, so an unresolvable foreign
                # key does not raise for a column with no peers.
                try:
                    fk_col = fk.column
                except (exc.NoReferencedColumnError,
                        exc.NoReferencedTableError):
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                if fk_col.shares_lineage(c) and \
                        (not only_synonyms or
                         c.name == col.name):
                    omit.add(col)
                    break

    if clauses:
        def visit_binary(binary):
            if binary.operator == operators.eq:
                # recomputed on each visit since "omit" grows as
                # equivalents are located
                cols = util.column_set(
                    chain(*[c.proxy_set for c in columns.difference(omit)]))
                if binary.left in cols and binary.right in cols:
                    for c in reversed(columns):
                        if c.shares_lineage(binary.right) and \
                                (not only_synonyms or
                                 c.name == binary.left.name):
                            omit.add(c)
                            break
        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {'binary': visit_binary})

    return ColumnSet(columns.difference(omit))
|
|
|
|
|
|
def criterion_as_pairs(expression, consider_as_foreign_keys=None,
                       consider_as_referenced_keys=None, any_operator=False):
    """Traverse an expression and collect pairs of columns from its
    binary criteria.

    Only binary expressions whose two sides are both ColumnElement
    objects are considered; unless ``any_operator`` is set, only those
    using the ``eq`` operator.

    * With ``consider_as_foreign_keys``, a pair is ``(other, fk)`` where
      ``fk`` is the side found in that collection.
    * With ``consider_as_referenced_keys``, a pair is
      ``(referenced, other)`` where ``referenced`` is the side found in
      that collection.
    * With neither, both sides must be Column objects and the pair is
      ``(b, a)`` when ``a.references(b)``.

    Raises ArgumentError if both collections are given.  Returns the
    list of pairs.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError("Can only specify one of "
                                "'consider_as_foreign_keys' or "
                                "'consider_as_referenced_keys'")

    pairs = []

    def member_and_other(left, right, keys):
        # Return (side present in keys, opposite side), trying the left
        # side first, or None.  A side qualifies only if the opposite
        # side compares equal to it or is not itself present in keys.
        if left in keys and \
                (right.compare(left) or right not in keys):
            return left, right
        if right in keys and \
                (left.compare(right) or left not in keys):
            return right, left
        return None

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (isinstance(left, ColumnElement) and
                isinstance(right, ColumnElement)):
            return

        if consider_as_foreign_keys:
            located = member_and_other(
                left, right, consider_as_foreign_keys)
            if located is not None:
                fk_side, other_side = located
                pairs.append((other_side, fk_side))
        elif consider_as_referenced_keys:
            located = member_and_other(
                left, right, consider_as_referenced_keys)
            if located is not None:
                pairs.append(located)
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {'binary': visit_binary})
    return pairs
|
|
|
|
|
|
class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(self, selectable, equivalents=None,
                 include_fn=None, exclude_fn=None,
                 adapt_on_names=False, anonymize_labels=False):
        # selectable: the target that columns are adapted to
        # equivalents: optional mapping of a column to other columns
        #   that may be tried in its place
        # include_fn / exclude_fn: optional predicates limiting which
        #   columns are replaced
        # adapt_on_names: fall back to a lookup by column name
        self.selectable = selectable
        self.equivalents = util.column_dict(equivalents or {})
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.adapt_on_names = adapt_on_names
        self.__traverse_options__ = {
            'stop_on': [selectable],
            'anonymize_labels': anonymize_labels,
        }

    def _corresponding_column(self, col, require_embedded,
                              _seen=util.EMPTY_SET):
        """Return the column of ``self.selectable`` corresponding to
        ``col``, or None.

        If there is no direct correspondence, the equivalents of
        ``col`` are tried recursively; ``_seen`` holds the columns
        already expanded so that cyclical equivalents terminate.
        Lastly, when ``adapt_on_names`` is set, a column of the same
        name is looked up in ``self.selectable.c``.

        """
        direct = self.selectable.corresponding_column(
            col, require_embedded=require_embedded)
        if direct is not None:
            return direct

        if col in self.equivalents and col not in _seen:
            visited = _seen.union([col])
            for equiv in self.equivalents[col]:
                via_equiv = self._corresponding_column(
                    equiv, require_embedded=require_embedded,
                    _seen=visited)
                if via_equiv is not None:
                    return via_equiv

        if self.adapt_on_names:
            return self.selectable.c.get(col.name)
        return None

    def replace(self, col):
        """Return the replacement for ``col``, or None for no
        replacement.

        A FromClause which ``self.selectable`` is derived from is
        replaced by the selectable itself.  Column elements accepted by
        the include/exclude functions are replaced by their
        corresponding column, if any.

        """
        if isinstance(col, FromClause) and \
                self.selectable.is_derived_from(col):
            return self.selectable

        if not isinstance(col, ColumnElement):
            return None
        if self.include_fn and not self.include_fn(col):
            return None
        if self.exclude_fn and self.exclude_fn(col):
            return None

        return self._corresponding_column(col, True)
|
|
|
|
|
|
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This because "annotated" items all have the same hash identity as their
      parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      We don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(self, selectable, equivalents=None,
                 chain_to=None, adapt_required=False,
                 include_fn=None, exclude_fn=None,
                 adapt_on_names=False,
                 allow_label_resolve=True,
                 anonymize_labels=False):
        ClauseAdapter.__init__(self, selectable, equivalents,
                               include_fn=include_fn, exclude_fn=exclude_fn,
                               adapt_on_names=adapt_on_names,
                               anonymize_labels=anonymize_labels)

        if chain_to:
            self.chain(chain_to)
        self.columns = self._build_columns()
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    class _IncludeExcludeMapping(object):
        """Front for the .columns collection which applies the parent
        adapter's include/exclude functions to each key before the
        underlying collection is consulted."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (
                self.parent.exclude_fn and self.parent.exclude_fn(key)
            ):
                # a filtered-out key is not adapted by this adapter;
                # hand it to the wrapped adapter if there is one, else
                # return it as is
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def _build_columns(self):
        """Create the .columns collection for this adapter.

        The collection populates itself through _locate_col and is
        fronted by _IncludeExcludeMapping whenever an include or
        exclude function is present.

        """
        columns = util.populate_column_dict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            columns = self._IncludeExcludeMapping(self, columns)
        return columns

    def wrap(self, adapter):
        """Return a copy of this adapter which additionally passes each
        located column through the given adapter."""
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        # the copy needs its own collection, bound to its own
        # _locate_col
        ac.columns = ac._build_columns()

        return ac

    def traverse(self, obj):
        """Return the adapted form of ``obj`` via the .columns
        collection."""
        return self.columns[obj]

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _locate_col(self, col):
        """Adapt ``col``, then pass the result through the wrapped
        adapter, if any.

        Returns None when ``adapt_required`` is set and the column came
        back unchanged.

        """

        c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the .columns collection refers to a bound method of this
        # object and is rebuilt in __setstate__ rather than serialized
        d = self.__dict__.copy()
        del d['columns']
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Rebuild .columns the same way __init__ and wrap() do, so that
        # a restored adapter keeps the include/exclude front.
        self.columns = self._build_columns()
|