141 lines
4.1 KiB
Python
141 lines
4.1 KiB
Python
import re
|
|
from contextlib import contextmanager
|
|
from sqlalchemy import create_engine, MetaData
|
|
from sqlalchemy import orm
|
|
from sqlalchemy.orm.exc import UnmappedClassError
|
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
|
from sqlalchemy.ext.mutable import Mutable
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
|
import settings
|
|
import state
|
|
|
|
# Engine for the project's SQLite database file located at settings.db_path.
# NOTE(review): the URL prefix already ends in '////', so an absolute
# db_path (one starting with '/') yields 'sqlite://///...' (five slashes) —
# confirm the expected format of settings.db_path.
engine = create_engine('sqlite:////%s' % settings.db_path)

# Session factory bound to `engine`. scoped_session keeps one session per
# scope (per thread when no scopefunc is given); Session.remove() discards it.
Session = scoped_session(sessionmaker(bind=engine))

# Shared MetaData collection; it is assigned to Model.metadata further down.
metadata = MetaData()
|
|
|
|
|
|
class _QueryProperty(object):
    """Class-level descriptor that builds a query for the owning model.

    Reading ``SomeModel.query`` returns ``SomeModel.query_class`` bound to the
    model's mapper and to ``state.db.session``.  When the owning class is not
    mapped (``UnmappedClassError``), or its mapper is falsy, the result is
    ``None``.
    """

    def __init__(self):
        """The descriptor keeps no per-instance state."""

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            # Same evaluation order as before: query_class is looked up
            # before state.db.session is read.
            query = type.query_class(mapper, session=state.db.session) if mapper else None
        except UnmappedClassError:
            query = None
        return query
|
|
|
|
class BaseQuery(orm.Query):
    """Project-wide query class; currently identical to ``orm.Query``.

    ``Model.query_class`` is set to this class, so query helpers shared by
    every model can be added here in one place.
    """
|
|
|
|
# Declarative base class for the project's ORM models.
Model = declarative_base()

# Query class instantiated by the `query` descriptor below.
Model.query_class = BaseQuery

# Class-level descriptor: SomeModel.query yields a BaseQuery bound to
# state.db.session, or None when read on an unmapped class such as Model.
Model.query = _QueryProperty()

# Replace the base's MetaData with the module-level `metadata` —
# presumably so model tables are registered on that shared collection.
Model.metadata = metadata
|
|
|
|
@contextmanager
def session():
    """Yield the session stored on ``state.db``, creating it on first use.

    Nested ``with session()`` blocks share one session object:
    ``state.db.count`` tracks the nesting depth, and only the outermost exit
    closes the session and calls ``Session.remove()`` on the scoped registry.

    NOTE(review): ``state.db.session`` is never deleted after the outermost
    exit, so ``hasattr`` remains true afterwards and later calls reuse that
    closed session object instead of obtaining a new one from ``Session()``.
    ``_QueryProperty`` reads ``state.db.session`` directly and may rely on the
    attribute persisting — confirm before changing this.
    """
    if hasattr(state.db, 'session'):
        # A session object already exists (nested or earlier use): bump depth.
        state.db.count += 1
    else:
        state.db.session = Session()
        state.db.count = 1
    try:
        yield state.db.session
    finally:
        state.db.count -= 1
        if not state.db.count:
            # Outermost exit: release the session and clear the scoped registry.
            state.db.session.close()
            Session.remove()
|
|
|
|
class MutableDict(Mutable, dict):
    """A ``dict`` that reports in-place mutations through ``Mutable.changed()``.

    Intended for use with SQLAlchemy's mutation-tracking extension (e.g.
    ``MutableDict.as_mutable(...)``, not shown here).  Any change to the
    contents must call ``self.changed()`` so the owning object is flagged as
    modified.  ``dict``'s C-level ``update``/``pop``/``popitem``/``clear``/
    ``setdefault`` do not route through ``__setitem__``/``__delitem__``, so each
    of them is overridden here.  Without these overrides, such mutations would
    silently never be flushed.
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."
        if isinstance(value, MutableDict):
            return value
        if isinstance(value, dict):
            return MutableDict(value)
        # this call will raise ValueError
        return Mutable.coerce(key, value)

    def __setitem__(self, key, value):
        "Detect dictionary set events and emit change events."
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        "Detect dictionary del events and emit change events."
        dict.__delitem__(self, key)
        self.changed()

    def update(self, *args, **kwargs):
        "Bulk update, then emit a single change event."
        dict.update(self, *args, **kwargs)
        self.changed()

    def setdefault(self, key, default=None):
        "Insert *default* for a missing *key*; emit a change event only if inserted."
        missing = key not in self
        result = dict.setdefault(self, key, default)
        if missing:
            self.changed()
        return result

    def pop(self, *args):
        "Remove and return a value (same signature as dict.pop), then emit a change event."
        result = dict.pop(self, *args)
        self.changed()
        return result

    def popitem(self):
        "Remove and return an arbitrary item, then emit a change event."
        result = dict.popitem(self)
        self.changed()
        return result

    def clear(self):
        "Empty the dictionary, then emit a change event."
        dict.clear(self)
        self.changed()

    def __getstate__(self):
        # Pickle only the contents: the tracking state kept by Mutable refers
        # to parents through weak references, which cannot be pickled.
        return dict(self)

    def __setstate__(self, state):
        self.update(state)
|
|
|
|
def run_sql(sql):
    """Execute raw SQL in one session and commit once at the end.

    :param sql: a single statement string, or a sequence of statement
        strings executed in order on the session's connection.
    """
    statements = [sql] if isinstance(sql, str) else sql
    with session() as db:
        for statement in statements:
            db.connection().execute(statement)
        db.commit()
|
|
|
|
def table_exists(table):
    """Return True when ``sqlite_master`` holds a CREATE statement for *table*."""
    create_table = get_create_table(table)
    return create_table is not None
|
|
|
|
def get_create_table(table):
    """Return the stored CREATE TABLE statement for *table*, or None if absent."""
    query = "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?"
    with session() as db:
        result = db.connection().execute(query, (table, ))
        row = result.fetchone()
    if row:
        return row[0]
    return None
|
|
|
|
# Leading keywords that begin a table constraint (rather than a column
# definition) inside the parentheses of a CREATE TABLE statement.
_TABLE_CONSTRAINT_KEYWORDS = frozenset(['CONSTRAINT', 'PRIMARY', 'UNIQUE', 'CHECK', 'FOREIGN'])


def _split_definitions(create_table):
    """Split the parenthesised body of *create_table* into its top-level,
    comma-separated definitions (columns and table constraints), stripped.

    Commas nested in parentheses (``NUMERIC(10, 2)``, ``CHECK (x IN (1, 2))``)
    or inside quoted strings/identifiers (``DEFAULT 'a,b'``) are not
    separators.  The final definition is included even though no comma
    follows it.
    """
    body = create_table.split('(', 1)[1].rsplit(')', 1)[0]
    definitions = []
    current = []
    depth = 0
    quote = None
    for ch in body:
        if quote:
            if ch == quote:
                quote = None
        elif ch in '\'"`':
            quote = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            definitions.append(''.join(current).strip())
            current = []
            continue
        current.append(ch)
    definitions.append(''.join(current).strip())
    return [d for d in definitions if d]


def _column_name(definition):
    """Return the column name a definition declares, or None for a table constraint."""
    name = definition.split(None, 1)[0]
    if name.upper() in _TABLE_CONSTRAINT_KEYWORDS:
        return None
    return name


def _unquote(name):
    """Strip SQL identifier quoting ("x", `x`, [x]) so names compare by text."""
    return name.strip('"`[]')


def get_table_columns(table):
    """Return the column names declared for *table*, in declaration order.

    Names come from the stored CREATE TABLE statement; table constraints
    (PRIMARY KEY, UNIQUE, CHECK, FOREIGN KEY, CONSTRAINT ...) are skipped.
    Quoted identifiers are returned as written.  Raises AttributeError when
    the table does not exist (get_create_table returns None).
    """
    names = (_column_name(d) for d in _split_definitions(get_create_table(table)))
    return [name for name in names if name is not None]


def _drop_columns_sql(table, create_table, columns):
    """Build the statements that rebuild *table* without *columns*.

    Follows SQLite's documented schema-change procedure: create the new
    table under a temporary name, copy the data, drop the original, then
    rename.  The original table is never renamed, so references to it in
    other schema objects are not rewritten.  A failure before the DROP also
    leaves the original table intact.

    Returns an empty list when none of *columns* is declared in
    *create_table*.  Raises ValueError if every column would be dropped.
    """
    dropped = {_unquote(c) for c in columns}
    definitions = _split_definitions(create_table)
    kept = []
    for definition in definitions:
        name = _column_name(definition)
        # Match on the whole column name, never on a substring (dropping
        # `id` must not touch `user_id`).
        if name is not None and _unquote(name) in dropped:
            continue
        kept.append(definition)
    if len(kept) == len(definitions):
        return []
    kept_columns = [n for n in map(_column_name, kept) if n is not None]
    if not kept_columns:
        raise ValueError('cannot drop every column of table %s' % table)
    column_list = ','.join(kept_columns)
    # Preserve table options written after the closing parenthesis
    # (e.g. WITHOUT ROWID).
    suffix = create_table.rsplit(')', 1)[1]
    new_table = '%s_new' % table
    # Statements are assembled by concatenation, never str.format, so braces
    # in the DDL (e.g. DEFAULT '{}') are passed through untouched.
    return [
        'DROP TABLE IF EXISTS %s' % new_table,
        'CREATE TABLE %s (\n\t%s\n)%s' % (new_table, ',\n\t'.join(kept), suffix),
        'INSERT INTO %s (%s) SELECT %s FROM %s' % (new_table, column_list, column_list, table),
        'DROP TABLE %s' % table,
        'ALTER TABLE %s RENAME TO %s' % (new_table, table),
    ]


def drop_columns(table, columns):
    """Remove *columns* (one name or a list of names) from *table*.

    The table is rebuilt from its stored CREATE TABLE statement minus the
    dropped column definitions, and the remaining columns' data is copied
    across in one session commit.  Does nothing when none of the columns
    exist.  Indexes and triggers on the table are dropped along with the
    original table and are not recreated.

    NOTE(review): SQLite's procedure also recommends disabling the
    foreign_keys pragma during the rebuild. With it enabled, DROP TABLE
    performs an implicit DELETE, which may fail or cascade — confirm the
    pragma state.  A table literally named ``<table>_new`` would be dropped.
    """
    if isinstance(columns, str):
        columns = [columns]
    statements = _drop_columns_sql(table, get_create_table(table), columns)
    if not statements:
        return
    with session() as s:
        for q in statements:
            s.connection().execute(q)
        s.commit()
|
|
|
|
def get_layout():
    """Describe the database schema as sorted lists of object names.

    Returns a dict with ``'tables'`` and ``'indexes'`` keys, each a sorted
    list of names read from ``sqlite_master``.  SQLite's internal objects
    (names beginning with ``sqlite_``) are excluded.
    """
    layout = {'tables': []}
    lookups = [
        ('tables', "SELECT name FROM sqlite_master WHERE type='table'"),
        ('indexes', "SELECT name FROM sqlite_master WHERE type='index'"),
    ]
    with session() as db:
        for key, query in lookups:
            rows = db.connection().execute(query).fetchall()
            layout[key] = sorted(row[0] for row in rows if not row[0].startswith('sqlite_'))
    return layout
|