"""SQLAlchemy engine/session setup plus small SQLite schema helpers."""
import re
from contextlib import contextmanager

from sqlalchemy import create_engine, MetaData
from sqlalchemy import orm
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.ext.declarative import declarative_base

import settings
import state

engine = create_engine('sqlite:////%s' % settings.db_path)
Session = scoped_session(sessionmaker(bind=engine))

metadata = MetaData()


class _QueryProperty(object):
    """Descriptor that exposes ``Model.query`` on mapped classes.

    Accessing the attribute builds a new ``query_class`` instance for the
    owning class, bound to ``state.db.session``.
    """

    def __init__(self):
        pass

    def __get__(self, obj, type):
        """Return a query for the owning class, or None if it is not mapped.

        ``state.db.session`` is read directly, so it must already have been
        set (``session()`` in this module is what assigns it).
        """
        try:
            mapper = orm.class_mapper(type)
            if mapper:
                return type.query_class(mapper, session=state.db.session)
        except UnmappedClassError:
            return None
        # No usable mapper: same result as the unmapped case.
        return None


class BaseQuery(orm.Query):
    """Query class used by ``Model.query``; a hook for shared query helpers."""
    pass


Model = declarative_base()
Model.query_class = BaseQuery
Model.query = _QueryProperty()
Model.metadata = metadata


@contextmanager
def session():
    """Yield the shared session stored on ``state.db``, with nesting support.

    The first entry creates ``state.db.session`` and sets ``state.db.count``
    to 1; nested entries only increment the counter. When the outermost
    block exits (counter back to 0) the session is closed and removed from
    the scoped-session registry.

    The ``session`` attribute is never deleted from ``state.db``, so later
    calls find it via ``hasattr`` and reuse the same object.
    NOTE(review): this relies on a closed SQLAlchemy session being usable
    again -- confirm for the SQLAlchemy version in use.
    """
    if hasattr(state.db, 'session'):
        state.db.count += 1
    else:
        state.db.session = Session()
        state.db.count = 1
    try:
        yield state.db.session
    finally:
        state.db.count -= 1
        if not state.db.count:
            state.db.session.close()
            Session.remove()


class MutableDict(Mutable, dict):
    """A dict that notifies SQLAlchemy's mutation tracking when it changes.

    Every in-place mutator calls ``self.changed()``. ``dict``'s built-in
    ``update``/``pop``/``popitem``/``clear``/``setdefault`` do not go through
    ``__setitem__``/``__delitem__``, so each is overridden explicitly;
    otherwise changes made through them would not be flagged.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain dictionaries to MutableDict."""
        if not isinstance(value, MutableDict):
            if isinstance(value, dict):
                return MutableDict(value)

            # this call will raise ValueError
            return Mutable.coerce(key, value)
        else:
            return value

    def __setitem__(self, key, value):
        """Detect dictionary set events and emit change events."""
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        """Detect dictionary del events and emit change events."""
        dict.__delitem__(self, key)
        self.changed()

    def update(self, *args, **kwargs):
        """Update in place and emit a change event."""
        dict.update(self, *args, **kwargs)
        self.changed()

    def setdefault(self, key, default=None):
        """Insert ``key`` if missing, emit a change event, return its value."""
        result = dict.setdefault(self, key, default)
        self.changed()
        return result

    def pop(self, *args):
        """Remove ``key`` and return its value, emitting a change event.

        Raises KeyError (before any event is emitted) when the key is
        missing and no default is given, like ``dict.pop``.
        """
        result = dict.pop(self, *args)
        self.changed()
        return result

    def popitem(self):
        """Remove and return an arbitrary item, emitting a change event."""
        result = dict.popitem(self)
        self.changed()
        return result

    def clear(self):
        """Remove all items and emit a change event."""
        dict.clear(self)
        self.changed()


def run_sql(sql):
    """Execute a single SQL statement on the shared session and commit."""
    with session() as s:
        s.connection().execute(sql)
        s.commit()


def table_exists(table):
    """Return True if ``sqlite_master`` has a table named ``table``."""
    return get_create_table(table) is not None


def get_create_table(table):
    """Return the CREATE TABLE statement stored for ``table``, or None.

    The statement is read from the ``sql`` column of ``sqlite_master``;
    None is returned when no table with that name exists.
    """
    with session() as s:
        sql = "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?"
        row = s.connection().execute(sql, (table, )).fetchone()
        return row[0] if row else None


def get_table_columms(table):
    """Return column names parsed from the stored CREATE TABLE text.

    The text before each comma (``.`` does not match newlines, so a segment
    never spans lines) is stripped and its first word is taken as a name;
    the final entry is then dropped.
    NOTE(review): dropping the last entry presumably removes a trailing
    table constraint -- this depends on how the schema text is formatted;
    verify against the real schemas.

    Raises TypeError if the table does not exist, because
    ``get_create_table`` returns None in that case.
    """
    create_table = get_create_table(table)
    segments = re.findall(r'(.*?),', create_table)
    return [segment.strip().split()[0] for segment in segments][:-1]


def drop_columns(table, columns):
    """Drop one or more columns from ``table`` by rebuilding it.

    The table is renamed to ``<table>_old``, recreated from its stored
    CREATE TABLE text with the unwanted column definitions cut out, the
    remaining columns are copied across, and the old table is dropped.

    :param table: name of the table to modify.
    :param columns: a column name or a list of column names to remove.
    """
    if isinstance(columns, str):
        columns = [columns]
    new_columns = [c for c in get_table_columms(table) if c not in columns]
    info = {
        'table': table,
        'columns': ','.join(new_columns),
    }
    create_table = get_create_table(table)
    # NOTE(review): the patterns below are whitespace-sensitive; confirm they
    # match the formatting of the CREATE TABLE text stored in sqlite_master.
    for column in columns:
        # Escape the name so regex metacharacters in it are matched literally.
        pattern = '( %s .*?,)' % re.escape(column)
        create_table = re.sub(pattern, '', create_table)
    # Re-flow the statement to one column definition per line.
    create_table = create_table.replace('\n', '').replace(',', ',\n')
    create_table = re.sub('\n *', '\n ', create_table).replace('( ', '(\n ')
    # Only the templates are formatted: the CREATE TABLE text comes from the
    # database and may legitimately contain '{' or '}' (e.g. in a DEFAULT),
    # which str.format would misinterpret.
    statements = [
        'ALTER TABLE {table} RENAME TO {table}_old'.format(**info),
        create_table,
        'INSERT INTO {table} ({columns}) SELECT {columns} FROM {table}_old'.format(**info),
        'DROP TABLE {table}_old'.format(**info),
    ]
    with session() as s:
        for statement in statements:
            s.connection().execute(statement)
        s.commit()


def get_layout():
    """Return the sorted names of the database's tables and indexes.

    :returns: ``{'tables': [...], 'indexes': [...]}``; names starting with
        ``sqlite_`` are left out of both lists.
    """
    def visible_names(connection, sql):
        # First column of every row, minus SQLite's own 'sqlite_' objects.
        rows = connection.execute(sql).fetchall()
        return sorted([r[0] for r in rows if not r[0].startswith('sqlite_')])

    with session() as s:
        tables = visible_names(
            s.connection(),
            "SELECT name FROM sqlite_master WHERE type='table'")
        indexes = visible_names(
            s.connection(),
            "SELECT name FROM sqlite_master WHERE type='index'")
    return {'tables': tables, 'indexes': indexes}