openmedialibrary/oml/db.py

142 lines
4.2 KiB
Python
Raw Normal View History

2016-01-16 15:57:15 +00:00
import re
2014-08-09 16:14:14 +00:00
from contextlib import contextmanager
from sqlalchemy import create_engine, MetaData
from sqlalchemy import orm
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm import sessionmaker, scoped_session
2014-05-04 17:26:43 +00:00
from sqlalchemy.ext.mutable import Mutable
2014-08-09 15:03:16 +00:00
from sqlalchemy.ext.declarative import declarative_base
2014-08-09 16:14:14 +00:00
import settings
import state
2014-08-09 18:32:41 +00:00
2016-01-31 06:46:53 +00:00
# Application-wide SQLite engine.  ``connect_args`` is forwarded to the DB-API
# ``connect()`` call; for sqlite3, ``timeout`` is the number of seconds a
# connection waits on a locked database before giving up.
engine = create_engine(
    'sqlite:///%s' % settings.db_path,
    connect_args={'timeout': 90},
)

# Scoped session registry; session() below hands out and releases instances.
_session_factory = sessionmaker(bind=engine)
Session = scoped_session(_session_factory)

# Shared schema metadata (installed as ``Model.metadata`` further down).
metadata = MetaData()
class _QueryProperty(object):
    """Descriptor that backs ``Model.query``.

    Reading ``SomeModel.query`` returns
    ``SomeModel.query_class(mapper, session=state.db.session)``.  For a class
    that SQLAlchemy has not mapped (for example the declarative base itself),
    ``orm.class_mapper`` raises ``UnmappedClassError`` and ``None`` is
    returned instead.

    NOTE(review): ``state.db.session`` is only created by ``session()``, so
    accessing ``.query`` before any ``with session():`` has run presumably
    raises AttributeError -- confirm against callers.
    """

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            return type.query_class(mapper, session=state.db.session) if mapper else None
        except UnmappedClassError:
            return None
2014-08-09 15:03:16 +00:00
2014-08-09 16:14:14 +00:00
class BaseQuery(orm.Query):
    """Query class used for every model (installed as ``Model.query_class``).

    It currently adds nothing to :class:`sqlalchemy.orm.Query`; presumably it
    is kept as a hook for project-wide query helpers.
    """
2014-08-09 15:03:16 +00:00
# Declarative base class shared by all mapped models, bound to the module's
# MetaData so every table is registered in one place.
Model = declarative_base()
Model.metadata = metadata

# ``SomeModel.query`` yields a BaseQuery bound to ``state.db.session``.
Model.query_class = BaseQuery
Model.query = _QueryProperty()
2014-08-09 15:03:16 +00:00
2014-08-09 16:14:14 +00:00
@contextmanager
def session():
    """Yield the session stored on ``state.db``, shared by nested callers.

    The first (outermost) entry creates a session from the scoped
    ``Session`` registry and sets ``state.db.count`` to 1.  Nested entries
    reuse the same object and increment the count.  When the count drops
    back to zero, the session is closed and ``Session.remove()`` discards the
    registry's instance.

    NOTE(review): ``state.db.session`` is not deleted after closing.  Later
    calls therefore take the "already exists" path and keep reusing that
    closed session object instead of calling ``Session()`` again.  Other
    code (``Model.query``) reads the attribute directly, so this is left
    unchanged -- confirm intent.
    """
    if not hasattr(state.db, 'session'):
        state.db.session = Session()
        state.db.count = 1
    else:
        state.db.count += 1

    try:
        yield state.db.session
    finally:
        state.db.count -= 1
        if state.db.count == 0:
            state.db.session.close()
            Session.remove()
2014-05-04 17:26:43 +00:00
class MutableDict(Mutable, dict):
    """A ``dict`` that tells SQLAlchemy's mutation tracking when it changes.

    Every in-place mutation calls ``self.changed()`` (from ``Mutable``), so
    the parent object is flagged as modified and the new contents are
    written on flush.  Besides ``__setitem__``/``__delitem__``, the bulk
    mutators (``update``, ``setdefault``, ``pop``, ``popitem``, ``clear``)
    are covered too.  Without them, edits such as ``d.update(...)`` would
    silently never be persisted.
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."
        if isinstance(value, MutableDict):
            return value
        if isinstance(value, dict):
            return MutableDict(value)
        # Mutable.coerce raises ValueError for any other type.
        return Mutable.coerce(key, value)

    def __setitem__(self, key, value):
        "Detect dictionary set events and emit change events."
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        "Detect dictionary del events and emit change events."
        dict.__delitem__(self, key)
        self.changed()

    def setdefault(self, key, default=None):
        "Like dict.setdefault, emitting a change event."
        result = dict.setdefault(self, key, default)
        self.changed()
        return result

    def update(self, *args, **kwargs):
        "Like dict.update, emitting a change event."
        dict.update(self, *args, **kwargs)
        self.changed()

    def pop(self, *args):
        "Like dict.pop, emitting a change event (KeyError propagates first)."
        result = dict.pop(self, *args)
        self.changed()
        return result

    def popitem(self):
        "Like dict.popitem, emitting a change event."
        result = dict.popitem(self)
        self.changed()
        return result

    def clear(self):
        "Like dict.clear, emitting a change event."
        dict.clear(self)
        self.changed()

    def __getstate__(self):
        # Pickle only the mapping; Mutable's weak parent registry is not
        # picklable and is rebuilt when the value is re-associated.
        return dict(self)

    def __setstate__(self, state):
        self.update(state)
2016-01-16 15:57:15 +00:00
def run_sql(sql):
    """Execute raw SQL statement(s) in a single session, then commit.

    :param sql: one SQL string, or an iterable of SQL strings executed in
        order on the session's connection.
    """
    statements = [sql] if isinstance(sql, str) else sql
    with session() as s:
        for statement in statements:
            s.connection().execute(statement)
        s.commit()
def table_exists(table):
    """Return True when ``sqlite_master`` has a CREATE TABLE entry for ``table``."""
    create_sql = get_create_table(table)
    return create_sql is not None
def get_create_table(table):
    """Return the CREATE TABLE statement SQLite stored for ``table``.

    The name is passed as a bound parameter.  Returns None when no table of
    that name exists.
    """
    query = "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?"
    with session() as s:
        result = s.connection().execute(query, (table, ))
        row = result.fetchone()
        if row:
            return row[0]
        return None
2016-01-17 08:19:31 +00:00
def get_table_columns(table, create_table=None):
    """Return the column names declared in ``table``'s CREATE TABLE statement.

    The text between the first ``(`` and the last ``)`` is split on commas at
    parenthesis depth zero only.  This keeps type arguments such as
    ``NUMERIC(10, 2)`` and composite constraints such as
    ``PRIMARY KEY (a, b)`` in one piece, and the final definition is
    included even though it has no trailing comma.  The first word of each
    definition is its name.  Only all-lowercase names are kept, which filters
    out upper-case constraint clauses (``PRIMARY KEY``, ``FOREIGN KEY``,
    ``UNIQUE``, ``CHECK``).

    :param table: table name looked up via ``get_create_table``.
    :param create_table: optional CREATE TABLE text to parse instead of
        querying the database (``table`` is then not consulted).
    :returns: column names in declaration order, or ``[]`` if the table
        does not exist.
    """
    if create_table is None:
        create_table = get_create_table(table)
    if create_table is None:
        return []
    body = create_table.split('(', 1)[1].rsplit(')', 1)[0]

    definitions = []
    depth = 0
    start = 0
    for i, ch in enumerate(body):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            definitions.append(body[start:i])
            start = i + 1
    definitions.append(body[start:])

    columns = []
    for definition in definitions:
        words = definition.split()
        # Skip empty fragments; keep lowercase names (constraints are upper-case).
        if words and words[0].islower():
            columns.append(words[0])
    return columns
2016-01-16 15:57:15 +00:00
def drop_columns(table, columns):
    """Remove one or more columns from ``table`` by rebuilding the table.

    Steps:

    1. Rename the table to ``<table>_old``.
    2. Recreate it from its stored CREATE TABLE text without the dropped
       column definitions.
    3. Copy the surviving columns across.
    4. Drop ``<table>_old``.

    All of this runs in one session and is committed at the end.

    The stored CREATE TABLE is split on commas at parenthesis depth zero.  A
    definition is dropped only when its first word equals one of ``columns``
    exactly, so dropping ``id`` never touches ``item_id``.  The CREATE TABLE
    text is executed verbatim (never passed through ``str.format``), so
    braces in DEFAULT values are safe.

    NOTE(review): indexes that lived on the original table are dropped along
    with ``<table>_old`` and are not recreated here.  Constraints that
    reference a dropped column will make the new CREATE TABLE fail.

    :param table: table to rebuild.
    :param columns: a column name or an iterable of column names to drop.
    :raises ValueError: if ``table`` does not exist.
    """
    if isinstance(columns, str):
        columns = [columns]
    dropped = set(columns)

    create_table = get_create_table(table)
    if create_table is None:
        raise ValueError('no such table: %s' % table)
    head, rest = create_table.split('(', 1)
    body, tail = rest.rsplit(')', 1)

    # Split the column/constraint list on top-level commas only.
    definitions = []
    depth = 0
    start = 0
    for i, ch in enumerate(body):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            definitions.append(body[start:i].strip())
            start = i + 1
    definitions.append(body[start:].strip())

    kept = []
    new_columns = []
    for definition in definitions:
        if not definition:
            continue
        name = definition.split()[0]
        if name in dropped:
            continue
        kept.append(definition)
        # Lowercase first words are columns; upper-case ones are constraints.
        if name.islower():
            new_columns.append(name)

    create_table = '%s(\n    %s\n)%s' % (head, ',\n    '.join(kept), tail)
    info = {
        'table': table,
        'columns': ','.join(new_columns),
    }
    sql = [
        'DROP TABLE IF EXISTS {table}_old'.format(**info),
        'ALTER TABLE {table} RENAME TO {table}_old'.format(**info),
        create_table,
        'INSERT INTO {table} ({columns}) SELECT {columns} FROM {table}_old'.format(**info),
        'DROP TABLE {table}_old'.format(**info),
    ]
    with session() as s:
        for q in sql:
            s.connection().execute(q)
        s.commit()
2016-01-17 06:29:06 +00:00
def get_layout():
    """Describe the database schema as sorted lists of table and index names.

    Names beginning with ``sqlite_`` (SQLite's internal objects, e.g.
    auto-indexes) are excluded.

    :returns: ``{'tables': [...], 'indexes': [...]}``.
    """
    queries = (
        ('tables', "SELECT name FROM sqlite_master WHERE type='table'"),
        ('indexes', "SELECT name FROM sqlite_master WHERE type='index'"),
    )
    layout = {}
    with session() as s:
        for key, query in queries:
            rows = s.connection().execute(query).fetchall()
            names = [row[0] for row in rows]
            layout[key] = sorted(name for name in names if not name.startswith('sqlite_'))
    return layout