280 lines
10 KiB
Python
280 lines
10 KiB
Python
|
from sqlalchemy.sql.expression import _BindParamClause
|
||
|
from sqlalchemy.ext.compiler import compiles
|
||
|
from sqlalchemy import schema, text
|
||
|
from sqlalchemy import types as sqltypes
|
||
|
|
||
|
from ..compat import string_types, text_type, with_metaclass
|
||
|
from .. import util
|
||
|
from . import base
|
||
|
|
||
|
class ImplMeta(type):
    """Metaclass that registers implementation classes by dialect name.

    Any class created with this metaclass whose own class body defines
    ``__dialect__`` is stored in the module-level ``_impls`` dict under
    that name. A later class with the same ``__dialect__`` replaces the
    earlier entry.
    """

    def __init__(cls, classname, bases, dict_):
        # type.__init__ always returns None; there is nothing to propagate,
        # and __init__ itself must return None.
        type.__init__(cls, classname, bases, dict_)
        # dict_ is the class body namespace, so only classes that define
        # __dialect__ directly are registered, not ones that just inherit it.
        if '__dialect__' in dict_:
            _impls[dict_['__dialect__']] = cls
|
||
|
|
||
|
# Registry of implementation classes keyed by their ``__dialect__`` value;
# ImplMeta adds an entry whenever a class body defines ``__dialect__``.
_impls = dict()
|
||
|
|
||
|
class DefaultImpl(with_metaclass(ImplMeta)):
    """Provide the entrypoint for major migration operations,
    including database-specific behavioral variances.

    While individual SQL/DDL constructs already provide
    for database-specific implementations, variances here
    allow for entirely different sequences of operations
    to take place for a particular migration, such as
    SQL Server's special 'IDENTITY INSERT' step for
    bulk inserts.

    """
    # Registry key: ImplMeta stores this class in ``_impls['default']``.
    __dialect__ = 'default'

    # Class-level default; an instance overrides it when the constructor
    # receives a ``transactional_ddl`` argument that is not None.
    transactional_ddl = False
    # Appended to every statement rendered in offline (``as_sql``) mode.
    command_terminator = ";"

    def __init__(self, dialect, connection, as_sql,
                 transactional_ddl, output_buffer,
                 context_opts):
        """Set up an implementation bound to a dialect and a connection.

        :param dialect: dialect used to compile constructs when rendering
         SQL in offline mode.
        :param connection: connection that statements are executed on in
         online mode; also exposed as :attr:`bind`.
        :param as_sql: if true, statements are written to ``output_buffer``
         as text instead of being executed.
        :param transactional_ddl: overrides the class-level
         ``transactional_ddl`` flag unless ``None``.
        :param output_buffer: stream with ``write()`` and ``flush()`` that
         receives offline output.
        :param context_opts: options stored unmodified on the instance.
        """
        self.dialect = dialect
        self.connection = connection
        self.as_sql = as_sql
        self.output_buffer = output_buffer
        # Per-instance scratch dictionary; this base class never reads it.
        self.memo = {}
        self.context_opts = context_opts
        if transactional_ddl is not None:
            self.transactional_ddl = transactional_ddl

    @classmethod
    def get_by_dialect(cls, dialect):
        """Return the impl class registered for ``dialect.name``.

        Raises ``KeyError`` if no class was registered under that name.
        """
        return _impls[dialect.name]

    def static_output(self, text):
        """Write ``text`` plus a blank line to the output buffer, then flush."""
        self.output_buffer.write(text_type(text + "\n\n"))
        self.output_buffer.flush()

    @property
    def bind(self):
        """The connection passed to the constructor."""
        return self.connection

    def _exec(self, construct, execution_options=None,
              multiparams=(),
              params=util.immutabledict()):
        """Execute ``construct`` online, or render it as text offline.

        A plain string is first wrapped in :func:`sqlalchemy.text`.

        In ``as_sql`` mode the construct is handled as follows:

        * It is compiled against ``self.dialect``.
        * Tabs are replaced by spaces, and the result is stripped.
        * ``command_terminator`` is appended.
        * The text is written out via :meth:`static_output`.

        Passing ``multiparams`` or ``params`` in that mode raises
        ``Exception``.

        Otherwise the construct is executed on ``self.connection``, with
        ``execution_options`` applied first when given.
        """
        if isinstance(construct, string_types):
            construct = text(construct)
        if self.as_sql:
            if multiparams or params:
                # TODO: coverage
                raise Exception("Execution arguments not allowed with as_sql")
            self.static_output(text_type(
                construct.compile(dialect=self.dialect)
            ).replace("\t", " ").strip() + self.command_terminator)
        else:
            conn = self.connection
            if execution_options:
                conn = conn.execution_options(**execution_options)
            conn.execute(construct, *multiparams, **params)

    def execute(self, sql, execution_options=None):
        """Execute (or, in offline mode, render) a SQL string or construct."""
        self._exec(sql, execution_options)

    def alter_column(self, table_name, column_name,
                     nullable=None,
                     server_default=False,
                     name=None,
                     type_=None,
                     schema=None,
                     autoincrement=None,
                     existing_type=None,
                     existing_server_default=None,
                     existing_nullable=None,
                     existing_autoincrement=None
                     ):
        """Emit one ALTER construct per requested column change.

        Changes are emitted in this order:

        1. nullability, when ``nullable`` is not None;
        2. server default, when ``server_default`` is not False (so
           ``None`` is a real value meaning "drop the default");
        3. type, when ``type_`` is not None;
        4. rename, when ``name`` is not None. The rename always comes
           last.

        The ``existing_*`` arguments describe the column's current state
        and are forwarded to every construct. The ``autoincrement``
        arguments are not used here; only a warning is issued when they
        are given.
        """
        if autoincrement is not None or existing_autoincrement is not None:
            util.warn("autoincrement and existing_autoincrement "
                      "only make sense for MySQL")
        # Keyword arguments shared by every ALTER construct emitted below.
        common = dict(
            schema=schema,
            existing_type=existing_type,
            existing_server_default=existing_server_default,
            existing_nullable=existing_nullable,
        )
        if nullable is not None:
            self._exec(base.ColumnNullable(
                table_name, column_name, nullable, **common))
        if server_default is not False:
            self._exec(base.ColumnDefault(
                table_name, column_name, server_default, **common))
        if type_ is not None:
            self._exec(base.ColumnType(
                table_name, column_name, type_, **common))
        # do the new name last ;)
        if name is not None:
            self._exec(base.ColumnName(
                table_name, column_name, name, **common))

    def add_column(self, table_name, column, schema=None):
        """Emit an ADD COLUMN construct for ``column`` on ``table_name``."""
        self._exec(base.AddColumn(table_name, column, schema=schema))

    def drop_column(self, table_name, column, schema=None, **kw):
        """Emit a DROP COLUMN construct; extra keyword arguments are ignored."""
        self._exec(base.DropColumn(table_name, column, schema=schema))

    def add_constraint(self, const):
        """Emit ADD CONSTRAINT for ``const``.

        The constraint is skipped when it has a ``_create_rule`` callable
        that returns a false value when called with this impl.
        """
        if const._create_rule is None or \
                const._create_rule(self):
            self._exec(schema.AddConstraint(const))

    def drop_constraint(self, const):
        """Emit DROP CONSTRAINT for ``const``."""
        self._exec(schema.DropConstraint(const))

    def rename_table(self, old_table_name, new_table_name, schema=None):
        """Emit a table-rename construct."""
        self._exec(base.RenameTable(old_table_name,
                                    new_table_name, schema=schema))

    def create_table(self, table):
        """Emit CREATE TABLE for ``table``, then CREATE INDEX per index.

        When ``util.sqla_07`` is true (presumably SQLAlchemy 0.7 or
        later; confirm in util), the table's ``before_create`` and
        ``after_create`` events are dispatched around CREATE TABLE, with
        this impl as the DDL runner. Indexes are created after
        ``after_create`` fires.
        """
        if util.sqla_07:
            table.dispatch.before_create(table, self.connection,
                                         checkfirst=False,
                                         _ddl_runner=self)
        self._exec(schema.CreateTable(table))
        if util.sqla_07:
            table.dispatch.after_create(table, self.connection,
                                        checkfirst=False,
                                        _ddl_runner=self)
        for index in table.indexes:
            self._exec(schema.CreateIndex(index))

    def drop_table(self, table):
        """Emit DROP TABLE for ``table``."""
        self._exec(schema.DropTable(table))

    def create_index(self, index):
        """Emit CREATE INDEX for ``index``."""
        self._exec(schema.CreateIndex(index))

    def drop_index(self, index):
        """Emit DROP INDEX for ``index``."""
        self._exec(schema.DropIndex(index))

    def bulk_insert(self, table, rows, multiinsert=True):
        """INSERT each dict in ``rows`` into ``table``.

        :param rows: a list of dicts keyed by column name. A non-list
         raises ``TypeError``; only the first element is checked for
         being a dict.
        :param multiinsert: online mode only. If true, all rows are
         passed to one execute call as separate parameter sets; otherwise
         one INSERT is executed per row.

        In ``as_sql`` mode every row becomes its own INSERT. Each value is
        wrapped in a ``_literal_bindparam`` of the column's type so it is
        rendered inline; values that are already ``_literal_bindparam``
        are passed through unchanged.
        """
        if not isinstance(rows, list):
            raise TypeError("List expected")
        elif rows and not isinstance(rows[0], dict):
            raise TypeError("List of dictionaries expected")
        if self.as_sql:
            for row in rows:
                self._exec(table.insert(inline=True).values(**dict(
                    (k,
                     _literal_bindparam(k, v, type_=table.c[k].type)
                     if not isinstance(v, _literal_bindparam) else v)
                    for k, v in row.items()
                )))
        else:
            # work around http://www.sqlalchemy.org/trac/ticket/2461
            if not hasattr(table, '_autoincrement_column'):
                table._autoincrement_column = None
            if rows:
                if multiinsert:
                    self._exec(table.insert(inline=True), multiparams=rows)
                else:
                    for row in rows:
                        self._exec(table.insert(inline=True).values(**row))

    def compare_type(self, inspector_column, metadata_column):
        """Return a truthy value if the reflected and metadata types differ.

        If ``_compare_type_affinity`` reports that the affinities do not
        match, ``True`` is returned. Otherwise a comparator from
        ``_type_comparators`` keyed on the reflected type's affinity
        decides. When no comparator exists, the result is falsy
        (``None``).
        """
        conn_type = inspector_column.type
        metadata_type = metadata_column.type

        metadata_impl = metadata_type.dialect_impl(self.dialect)

        # work around SQLAlchemy bug "stale value for type affinity"
        # fixed in 0.7.4
        metadata_impl.__dict__.pop('_type_affinity', None)

        if conn_type._compare_type_affinity(
            metadata_impl
        ):
            comparator = _type_comparators.get(conn_type._type_affinity, None)

            return comparator and comparator(metadata_type, conn_type)
        else:
            return True

    def compare_server_default(self, inspector_column,
                               metadata_column,
                               rendered_metadata_default,
                               rendered_inspector_default):
        """Return whether the two rendered server defaults are unequal."""
        return rendered_inspector_default != rendered_metadata_default

    def correct_for_autogen_constraints(self, conn_uniques, conn_indexes,
                                        metadata_unique_constraints,
                                        metadata_indexes):
        """Hook for adjusting reflected and metadata constraint collections.

        It does nothing here; presumably dialect-specific subclasses
        override it.
        """
        pass

    def start_migrations(self):
        """A hook called when :meth:`.EnvironmentContext.run_migrations`
        is called.

        Implementations can set up per-migration-run state here.

        """

    def emit_begin(self):
        """Emit the string ``BEGIN``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("BEGIN" + self.command_terminator)

    def emit_commit(self):
        """Emit the string ``COMMIT``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("COMMIT" + self.command_terminator)
|
||
|
|
||
|
class _literal_bindparam(_BindParamClause):
    """Bind parameter whose value is rendered inline as a SQL literal.

    It adds nothing to its base class. It exists only so that a
    ``@compiles`` hook in this module can give it literal rendering.
    """
|
||
|
|
||
|
@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
    """Compile a ``_literal_bindparam`` as its literal value.

    It delegates to the compiler's ``render_literal_bindparam``,
    forwarding every compile keyword unchanged.
    """
    render_literal = compiler.render_literal_bindparam
    return render_literal(element, **kw)
|
||
|
|
||
|
|
||
|
def _string_compare(t1, t2):
    """Report whether two string types differ in length.

    Only an explicit length on ``t1`` counts. If ``t1.length`` is None,
    the types are treated as matching whatever ``t2.length`` is.
    """
    if t1.length is None:
        return False
    return t1.length != t2.length
|
||
|
|
||
|
def _numeric_compare(t1, t2):
    """Report whether two numeric types differ in precision or scale.

    Each attribute is compared only when it is explicitly set (not None)
    on ``t1``. An unset attribute on ``t1`` never causes a mismatch.
    """
    precision = t1.precision
    if precision is not None and precision != t2.precision:
        return True
    scale = t1.scale
    return scale is not None and scale != t2.scale
|
||
|
# Per-affinity comparators consulted by DefaultImpl.compare_type once two
# types share an affinity. Each is called as (metadata_type, reflected_type)
# and returns a truthy value when the types differ in its attributes.
_type_comparators = {
    sqltypes.String: _string_compare,
    sqltypes.Numeric: _numeric_compare,
}
|
||
|
|
||
|
|
||
|
|
||
|
|