openmedialibrary/oml/db.py

145 lines
4.2 KiB
Python

import re
from contextlib import contextmanager
from sqlalchemy import create_engine, MetaData
from sqlalchemy import orm
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.ext.declarative import declarative_base
import settings
import state
import logging
logger = logging.getLogger(__name__)
# Module-wide engine for the SQLite database file at settings.db_path.
# connect_args is handed to the DB-API connect(); for sqlite3, 'timeout' is the
# number of seconds to wait on a locked database before raising.
engine = create_engine('sqlite:///%s' % settings.db_path, connect_args={'timeout': 90})
# Session registry: Session() returns the session for the current scope
# (thread-local by default, since no scopefunc is given) and
# Session.remove() closes and discards it.
Session = scoped_session(sessionmaker(bind=engine))
# Shared MetaData object; Model.metadata is pointed at it further down so
# declarative models register their tables here.
metadata = MetaData()
class _QueryProperty(object):
    """Descriptor that gives model classes a ``query`` attribute.

    Each access to ``SomeModel.query`` (or the same attribute on an
    instance) builds a new ``owner.query_class`` for the model's mapper,
    bound to ``state.db.session``. For classes SQLAlchemy has not mapped,
    such as the declarative base itself, the attribute evaluates to ``None``.
    """

    def __get__(self, instance, owner):
        try:
            mapper = orm.class_mapper(owner)
            if not mapper:
                return None
            return owner.query_class(mapper, session=state.db.session)
        except UnmappedClassError:
            # Unmapped class (e.g. the bare declarative base): no query.
            return None
class BaseQuery(orm.Query):
    """``orm.Query`` subclass installed as ``Model.query_class``.

    It currently adds no behaviour of its own beyond ``orm.Query``.
    """
# Declarative base class for the ORM models.
Model = declarative_base()
# Query class that the `query` descriptor instantiates for each model.
Model.query_class = BaseQuery
# Class-level descriptor: SomeModel.query returns a BaseQuery bound to
# state.db.session (None when accessed on an unmapped class).
Model.query = _QueryProperty()
# Collect every model's table in the module-level MetaData defined above.
Model.metadata = metadata
@contextmanager
def session():
    """Yield the session stored on ``state.db``, opening one if needed.

    Calls nest. An inner ``with session()`` reuses the session opened by the
    outermost one and increments ``state.db.count``. When the outermost
    block exits and the count drops back to 0, the session is closed and
    removed from the ``Session`` registry.

    NOTE(review): ``state.db`` is presumably per-thread state; confirm in
    the ``state`` module.
    """
    if getattr(state.db, 'count', 0) > 0:
        # Nested use: share the session opened by an enclosing block.
        state.db.count += 1
    else:
        # Outermost use: always start from a fresh session. Keying on the
        # counter rather than hasattr(state.db, 'session') matters because
        # state.db.session is left pointing at the previous, already closed
        # session after cleanup. Testing for it would reuse that stale
        # object, detached from the Session registry, on every later call.
        state.db.session = Session()
        state.db.count = 1
    try:
        yield state.db.session
    finally:
        state.db.count -= 1
        if not state.db.count:
            state.db.session.close()
            Session.remove()
class MutableDict(Mutable, dict):
    """A ``dict`` that reports in-place mutations through ``Mutable.changed()``.

    Every ``dict`` method that modifies the mapping in place is overridden to
    call ``self.changed()`` afterwards, so SQLAlchemy sees the owning object
    as dirty. ``dict``'s own ``update``/``pop``/``popitem``/``setdefault``/
    ``clear`` do not go through ``__setitem__``/``__delitem__``, so each one
    has to be wrapped explicitly. Otherwise changes made through them would
    never be flushed.
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."
        if isinstance(value, MutableDict):
            return value
        if isinstance(value, dict):
            return MutableDict(value)
        # this call will raise ValueError
        return Mutable.coerce(key, value)

    def __setitem__(self, key, value):
        "Detect dictionary set events and emit change events."
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        "Detect dictionary del events and emit change events."
        dict.__delitem__(self, key)
        self.changed()

    def update(self, *args, **kwargs):
        "Detect bulk updates and emit change events."
        dict.update(self, *args, **kwargs)
        self.changed()

    def setdefault(self, key, default=None):
        "Detect possible insertion via setdefault and emit change events."
        result = dict.setdefault(self, key, default)
        self.changed()
        return result

    def pop(self, *args):
        "Detect removal via pop and emit change events."
        # A missing key without a default raises KeyError before changed().
        result = dict.pop(self, *args)
        self.changed()
        return result

    def popitem(self):
        "Detect removal via popitem and emit change events."
        result = dict.popitem(self)
        self.changed()
        return result

    def clear(self):
        "Detect clearing and emit change events."
        dict.clear(self)
        self.changed()
def run_sql(sql):
    """Execute one SQL statement, or a sequence of them, then commit once.

    A single string is treated as a one-statement list. Each statement is
    executed on the session's connection in order, and the session is
    committed after the last one.
    """
    statements = [sql] if isinstance(sql, str) else sql
    with session() as db_session:
        for statement in statements:
            db_session.connection().execute(statement)
        db_session.commit()
def table_exists(table):
    """Return True if sqlite_master has a CREATE TABLE entry for ``table``."""
    create_sql = get_create_table(table)
    return create_sql is not None
def get_create_table(table):
    """Return the CREATE TABLE statement SQLite stored for ``table``.

    The text is read from ``sqlite_master`` with a bound parameter. Returns
    ``None`` when no table of that name exists.
    """
    query = "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?"
    with session() as db_session:
        cursor = db_session.connection().execute(query, (table,))
        row = cursor.fetchone()
        if not row:
            return None
        return row[0]
def get_table_columns(table, create_table=None):
    """Return the lower-case column names declared in ``table``'s schema.

    The CREATE TABLE statement is fetched with ``get_create_table(table)``
    unless it is passed in as ``create_table``. The text between its first
    '(' and last ')' is split on commas that sit outside parentheses. This
    keeps types like ``NUMERIC(10, 2)`` and constraints like
    ``FOREIGN KEY(a, b)`` intact, and it also keeps the final definition
    when no trailing constraint follows it. From each definition, the first
    word is kept only if ``str.islower()`` holds. That filters out table
    constraints such as ``PRIMARY KEY (id)``, but it also skips column names
    containing upper-case letters.

    Raises AttributeError when the table does not exist (no statement).
    NOTE(review): parentheses inside string literals (e.g. DEFAULT '(')
    would confuse the depth count.
    """
    if create_table is None:
        create_table = get_create_table(table)
    body = create_table.split('(', 1)[1].rsplit(')', 1)[0]

    # Split the definition list on top-level commas only.
    definitions = []
    current = []
    depth = 0
    for char in body:
        if char == ',' and depth == 0:
            definitions.append(''.join(current))
            current = []
            continue
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        current.append(char)
    definitions.append(''.join(current))

    columns = []
    for definition in definitions:
        words = definition.split()
        if words and words[0].islower():
            columns.append(words[0])
    return columns
def _strip_column_definitions(create_table, columns):
    """Return ``create_table`` with the definitions of ``columns`` removed.

    The definition list between the first '(' and the last ')' is split on
    commas outside parentheses, so ``NUMERIC(10, 2)`` stays whole. A
    definition is removed only when its first word equals one of
    ``columns`` exactly, so dropping ``name`` leaves ``fullname`` alone. The
    remaining definitions are re-joined one per line.
    """
    head, rest = create_table.split('(', 1)
    body, tail = rest.rsplit(')', 1)

    definitions = []
    current = []
    depth = 0
    for char in body:
        if char == ',' and depth == 0:
            definitions.append(''.join(current))
            current = []
            continue
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        current.append(char)
    definitions.append(''.join(current))

    kept = []
    for definition in definitions:
        definition = definition.strip()
        if definition and definition.split()[0] not in columns:
            kept.append(definition)
    return '%s(\n %s\n)%s' % (head, ',\n '.join(kept), tail)


def drop_columns(table, columns):
    """Remove one or more columns from ``table`` by rebuilding it.

    The rebuild runs these steps in one session, then commits:

    1. Rename the table to ``<table>_old``.
    2. Recreate it from its own CREATE TABLE statement, minus the dropped
       column definitions.
    3. Copy the remaining columns (as reported by ``get_table_columns``)
       across.
    4. Drop ``<table>_old``.

    ``columns`` may be a single name or a list of names.

    NOTE(review): in SQLite, indexes on the table follow the rename and are
    removed by the final DROP TABLE; they are not recreated here. Confirm
    callers rebuild them (e.g. by comparing against get_layout()).
    """
    if isinstance(columns, str):
        columns = [columns]
    new_columns = [c for c in get_table_columns(table) if c not in columns]
    info = {
        'table': table,
        'columns': ','.join(new_columns),
    }
    create_table = _strip_column_definitions(get_create_table(table), columns)
    # Only the fixed templates go through str.format(). The CREATE TABLE
    # text comes from the schema and may legitimately contain '{' or '}'
    # (e.g. DEFAULT '{}'), which would make .format() raise.
    sql = [
        'DROP TABLE IF EXISTS {table}_old'.format(**info),
        'ALTER TABLE {table} RENAME TO {table}_old'.format(**info),
        create_table,
        'INSERT INTO {table} ({columns}) SELECT {columns} FROM {table}_old'.format(**info),
        'DROP TABLE {table}_old'.format(**info),
    ]
    with session() as s:
        for q in sql:
            s.connection().execute(q)
        s.commit()
def get_layout():
    """Describe the database schema as sorted lists of table and index names.

    Returns ``{'tables': [...], 'indexes': [...]}``, read from
    ``sqlite_master``. Names starting with ``sqlite_`` (SQLite's internal
    objects) are left out.
    """
    queries = (
        ('tables', "SELECT name FROM sqlite_master WHERE type='table'"),
        ('indexes', "SELECT name FROM sqlite_master WHERE type='index'"),
    )
    layout = {}
    with session() as db_session:
        for key, query in queries:
            rows = db_session.connection().execute(query).fetchall()
            names = [row[0] for row in rows]
            layout[key] = sorted(n for n in names if not n.startswith('sqlite_'))
    return layout