one sqlalchemy session per thread

This commit is contained in:
j 2014-08-09 18:14:14 +02:00
parent 0c08d37c56
commit 8b46a85d56
15 changed files with 140 additions and 102 deletions

View file

@ -13,7 +13,6 @@ import settings
import changelog import changelog
from db import session
import item.models import item.models
import user.models import user.models
import item.person import item.person

View file

@ -57,8 +57,8 @@ class Changelog(db.Model):
c.data = json.dumps([action] + list(args)) c.data = json.dumps([action] + list(args))
_data = str(c.revision) + str(c.timestamp) + c.data _data = str(c.revision) + str(c.timestamp) + c.data
c.sig = settings.sk.sign(_data, encoding='base64') c.sig = settings.sk.sign(_data, encoding='base64')
db.session.add(c) state.db.session.add(c)
db.session.commit() state.db.session.commit()
if state.nodes: if state.nodes:
state.nodes.queue('peered', 'pushChanges', [c.json()]) state.nodes.queue('peered', 'pushChanges', [c.json()])
@ -94,8 +94,8 @@ class Changelog(db.Model):
logger.debug('apply change from %s: %s', user.name, args) logger.debug('apply change from %s: %s', user.name, args)
if getattr(c, 'action_' + args[0])(user, timestamp, *args[1:]): if getattr(c, 'action_' + args[0])(user, timestamp, *args[1:]):
logger.debug('change applied') logger.debug('change applied')
db.session.add(c) state.db.session.add(c)
db.session.commit() state.db.session.commit()
if trigger: if trigger:
trigger_event('change', {}); trigger_event('change', {});
return True return True
@ -118,8 +118,8 @@ class Changelog(db.Model):
for c in cls.query.filter_by(user_id=settings.USER_ID): for c in cls.query.filter_by(user_id=settings.USER_ID):
_data = str(c.revision) + str(c.timestamp) + c.data _data = str(c.revision) + str(c.timestamp) + c.data
c.sig = settings.sk.sign(_data, encoding='base64') c.sig = settings.sk.sign(_data, encoding='base64')
db.session.add(c) state.db.session.add(c)
db.session.commit() state.db.session.commit()
def json(self): def json(self):
timestamp = self.timestamp or datetime2ts(self.created) timestamp = self.timestamp or datetime2ts(self.created)

View file

@ -1,19 +1,52 @@
from sqlalchemy import create_engine from contextlib import contextmanager
from sqlalchemy.orm import sessionmaker from sqlalchemy import create_engine, MetaData
from sqlalchemy import orm
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.mutable import Mutable from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
import settings import settings
import state
engine = create_engine('sqlite:////%s' % settings.db_path) engine = create_engine('sqlite:////%s' % settings.db_path)
Session = sessionmaker(bind=engine) Session = scoped_session(sessionmaker(bind=engine))
# create a Session metadata = MetaData()
session = Session()
class _QueryProperty(object):
def __init__(self):
pass
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=state.db.session)
except UnmappedClassError:
return None
class BaseQuery(orm.Query):
pass
Model = declarative_base() Model = declarative_base()
Model.query_class = BaseQuery
Model.query = _QueryProperty()
Model.metadata = metadata
@contextmanager
def session():
if hasattr(state.db, 'session'):
state.db.count += 1
else:
state.db.session = Session()
state.db.count = 1
yield
state.db.count -= 1
if not state.db.count:
Session.remove()
class MutableDict(Mutable, dict): class MutableDict(Mutable, dict):
@classmethod @classmethod

View file

@ -6,6 +6,7 @@ from threading import Thread
import time import time
import logging import logging
import db
import state import state
import settings import settings
import update import update
@ -43,7 +44,7 @@ class Downloads(Thread):
def run(self): def run(self):
time.sleep(2) time.sleep(2)
with self._app.app_context(): with db.session():
while self._running: while self._running:
self.download_next() self.download_next()
time.sleep(0.5) time.sleep(0.5)

View file

@ -17,6 +17,7 @@ from utils import resize_image
from settings import icons_db_path from settings import icons_db_path
import db
import logging import logging
logger = logging.getLogger('oml.item.icons') logger = logging.getLogger('oml.item.icons')
@ -120,7 +121,7 @@ def get_icon(app, id, type_, size, callback):
@run_async @run_async
def get_icon_app(app, id, type_, size, callback): def get_icon_app(app, id, type_, size, callback):
with app.app_context(): with db.session():
from item.models import Item from item.models import Item
item = Item.get(id) item = Item.get(id)
if not item: if not item:

View file

@ -41,8 +41,7 @@ from utils import remove_empty_folders
logger = logging.getLogger('oml.item.model') logger = logging.getLogger('oml.item.model')
metadata = sa.MetaData() user_items = sa.Table('useritem', db.metadata,
user_items = sa.Table('useritem', metadata,
sa.Column('user_id', sa.String(43), sa.ForeignKey('user.id')), sa.Column('user_id', sa.String(43), sa.ForeignKey('user.id')),
sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id')) sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id'))
) )
@ -97,8 +96,8 @@ class Item(db.Model):
item = cls(id=id) item = cls(id=id)
if info: if info:
item.info = info item.info = info
db.session.add(item) state.db.session.add(item)
db.session.commit() state.db.session.commit()
return item return item
def json(self, keys=None): def json(self, keys=None):
@ -166,7 +165,7 @@ class Item(db.Model):
elif isinstance(value, list): #empty list elif isinstance(value, list): #empty list
value = '' value = ''
setattr(s, key['id'], value) setattr(s, key['id'], value)
db.session.add(s) state.db.session.add(s)
def update_find(self): def update_find(self):
@ -176,7 +175,7 @@ class Item(db.Model):
v = v.decode('utf-8') v = v.decode('utf-8')
f.findvalue = unicodedata.normalize('NFKD', v).lower() f.findvalue = unicodedata.normalize('NFKD', v).lower()
f.value = v f.value = v
db.session.add(f) state.db.session.add(f)
for key in config['itemKeys']: for key in config['itemKeys']:
if key.get('find') or key.get('filter') or key.get('type') in [['string'], 'string']: if key.get('find') or key.get('filter') or key.get('type') in [['string'], 'string']:
@ -195,7 +194,7 @@ class Item(db.Model):
else: else:
f = Find.get(self.id, key['id']) f = Find.get(self.id, key['id'])
if f: if f:
db.session.delete(f) state.db.session.delete(f)
def update(self): def update(self):
for key in ('mediastate', 'coverRatio', 'previewRatio'): for key in ('mediastate', 'coverRatio', 'previewRatio'):
@ -218,15 +217,15 @@ class Item(db.Model):
self.save() self.save()
def save(self): def save(self):
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
def delete(self, commit=True): def delete(self, commit=True):
db.session.delete(self) state.db.session.delete(self)
Sort.query.filter_by(item_id=self.id).delete() Sort.query.filter_by(item_id=self.id).delete()
Transfer.query.filter_by(item_id=self.id).delete() Transfer.query.filter_by(item_id=self.id).delete()
if commit: if commit:
db.session.commit() state.db.session.commit()
meta_keys = ('title', 'author', 'date', 'publisher', 'edition', 'language') meta_keys = ('title', 'author', 'date', 'publisher', 'edition', 'language')
@ -392,13 +391,13 @@ class Item(db.Model):
if os.path.exists(path): if os.path.exists(path):
os.unlink(path) os.unlink(path)
remove_empty_folders(os.path.dirname(path)) remove_empty_folders(os.path.dirname(path))
db.session.delete(f) state.db.session.delete(f)
user = state.user() user = state.user()
if user in self.users: if user in self.users:
self.users.remove(user) self.users.remove(user)
for l in self.lists.filter_by(user_id=user.id): for l in self.lists.filter_by(user_id=user.id):
l.items.remove(self) l.items.remove(self)
db.session.commit() state.db.session.commit()
if not self.users: if not self.users:
self.delete() self.delete()
else: else:
@ -424,8 +423,8 @@ class Sort(db.Model):
f = cls.get(item_id) f = cls.get(item_id)
if not f: if not f:
f = cls(item_id=item_id) f = cls(item_id=item_id)
db.session.add(f) state.db.session.add(f)
db.session.commit() state.db.session.commit()
return f return f
for key in config['itemKeys']: for key in config['itemKeys']:
@ -467,8 +466,8 @@ class Find(db.Model):
f = cls.get(item, key) f = cls.get(item, key)
if not f: if not f:
f = cls(item_id=item, key=key) f = cls(item_id=item, key=key)
db.session.add(f) state.db.session.add(f)
db.session.commit() state.db.session.commit()
return f return f
class File(db.Model): class File(db.Model):
@ -499,8 +498,8 @@ class File(db.Model):
if path: if path:
f.path = path f.path = path
f.item_id = Item.get_or_create(id=sha1, info=info).id f.item_id = Item.get_or_create(id=sha1, info=info).id
db.session.add(f) state.db.session.add(f)
db.session.commit() state.db.session.commit()
return f return f
def __repr__(self): def __repr__(self):
@ -556,8 +555,8 @@ class File(db.Model):
self.save() self.save()
def save(self): def save(self):
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
class Transfer(db.Model): class Transfer(db.Model):
@ -587,8 +586,8 @@ class Transfer(db.Model):
return t return t
def save(self): def save(self):
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
class Metadata(db.Model): class Metadata(db.Model):
__tablename__ = 'metadata' __tablename__ = 'metadata'
@ -626,14 +625,14 @@ class Metadata(db.Model):
def save(self): def save(self):
self.modified = datetime.utcnow() self.modified = datetime.utcnow()
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
def reset(self): def reset(self):
user = state.user() user = state.user()
Changelog.record(user, 'resetmeta', self.key, self.value) Changelog.record(user, 'resetmeta', self.key, self.value)
db.session.delete(self) state.db.session.delete(self)
db.session.commit() state.db.session.commit()
self.update_items() self.update_items()
def edit(self, data): def edit(self, data):

View file

@ -41,6 +41,6 @@ class Person(db.Model):
self.sortname = unicodedata.normalize('NFKD', self.sortname) self.sortname = unicodedata.normalize('NFKD', self.sortname)
self.sortsortname = ox.sort_string(self.sortname) self.sortsortname = ox.sort_string(self.sortname)
self.numberofnames = len(self.name.split(' ')) self.numberofnames = len(self.name.split(' '))
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()

View file

@ -11,7 +11,7 @@ import ox
from app import app from app import app
import settings import settings
from settings import db import db
from item.models import File from item.models import File
from user.models import User, List from user.models import User, List
@ -29,7 +29,7 @@ extensions = ['epub', 'pdf', 'txt']
def remove_missing(): def remove_missing():
dirty = False dirty = False
with app.app_context(): with db.session():
prefs = settings.preferences prefs = settings.preferences
prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/') prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/')
if os.path.exists(prefix): if os.path.exists(prefix):
@ -39,7 +39,7 @@ def remove_missing():
dirty = True dirty = True
f.item.remove_file() f.item.remove_file()
if dirty: if dirty:
db.session.commit() state.db.session.commit()
def add_file(id, f, prefix, from_=None): def add_file(id, f, prefix, from_=None):
user = state.user() user = state.user()
@ -49,10 +49,10 @@ def add_file(id, f, prefix, from_=None):
item = file.item item = file.item
if 'primaryid' in file.info: if 'primaryid' in file.info:
del file.info['primaryid'] del file.info['primaryid']
db.session.add(file) state.db.session.add(file)
if 'primaryid' in item.info: if 'primaryid' in item.info:
item.meta['primaryid'] = item.info.pop('primaryid') item.meta['primaryid'] = item.info.pop('primaryid')
db.session.add(item) state.db.session.add(item)
item.users.append(user) item.users.append(user)
Changelog.record(user, 'additem', item.id, file.info) Changelog.record(user, 'additem', item.id, file.info)
if item.meta.get('primaryid'): if item.meta.get('primaryid'):
@ -65,7 +65,7 @@ def add_file(id, f, prefix, from_=None):
def run_scan(): def run_scan():
remove_missing() remove_missing()
with app.app_context(): with db.session():
prefs = settings.preferences prefs = settings.preferences
prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/') prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/')
if not prefix[-1] == '/': if not prefix[-1] == '/':
@ -96,7 +96,7 @@ def run_scan():
def run_import(options=None): def run_import(options=None):
options = options or {} options = options or {}
with app.app_context(): with db.session():
prefs = settings.preferences prefs = settings.preferences
prefix = os.path.expanduser(options.get('path', prefs['importPath'])) prefix = os.path.expanduser(options.get('path', prefs['importPath']))
if os.path.islink(prefix): if os.path.islink(prefix):

View file

@ -10,6 +10,7 @@ from tornado.ioloop import PeriodicCallback
import settings import settings
import directory import directory
import db
import state import state
import user import user
@ -81,7 +82,7 @@ class NodeHandler(tornado.web.RequestHandler):
@run_async @run_async
def api_call(app, action, key, args, callback): def api_call(app, action, key, args, callback):
with app.app_context(): with db.session():
u = user.models.User.get(key) u = user.models.User.get(key)
if action in ( if action in (
'requestPeering', 'acceptPeering', 'rejectPeering', 'removePeering' 'requestPeering', 'acceptPeering', 'rejectPeering', 'removePeering'
@ -125,7 +126,7 @@ class ShareHandler(tornado.web.RequestHandler):
def publish_node(app): def publish_node(app):
update_online() update_online()
if state.online: if state.online:
with app.app_context(): with db.session():
for u in user.models.User.query.filter_by(queued=True): for u in user.models.User.query.filter_by(queued=True):
logger.debug('adding queued node... %s', u.id) logger.debug('adding queued node... %s', u.id)
state.nodes.queue('add', u.id) state.nodes.queue('add', u.id)
@ -160,7 +161,7 @@ def update_online():
def check_nodes(app): def check_nodes(app):
if state.online: if state.online:
with app.app_context(): with db.session():
for u in user.models.User.query.filter_by(queued=True): for u in user.models.User.query.filter_by(queued=True):
if not state.nodes.is_online(u.id): if not state.nodes.is_online(u.id):
logger.debug('queued peering message for %s trying to connect...', u.id) logger.debug('queued peering message for %s trying to connect...', u.id)

View file

@ -26,6 +26,7 @@ from websocket import trigger_event
from localnodes import LocalNodes from localnodes import LocalNodes
from ssl_request import get_opener from ssl_request import get_opener
import state import state
import db
import logging import logging
logger = logging.getLogger('oml.nodes') logger = logging.getLogger('oml.nodes')
@ -55,7 +56,6 @@ class Node(Thread):
self.ping() self.ping()
def run(self): def run(self):
with self._app.app_context():
while self._running: while self._running:
action = self._q.get() action = self._q.get()
if not self._running: if not self._running:
@ -187,6 +187,7 @@ class Node(Thread):
@property @property
def user(self): def user(self):
with db.session():
return user.models.User.get_or_create(self.user_id) return user.models.User.get_or_create(self.user_id)
def can_connect(self): def can_connect(self):
@ -248,7 +249,6 @@ class Node(Thread):
}) })
def pullChanges(self): def pullChanges(self):
with self._app.app_context():
last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first() last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first()
from_revision = last.revision + 1 if last else 0 from_revision = last.revision + 1 if last else 0
logger.debug('pullChanges %s from %s', self.user.name, from_revision) logger.debug('pullChanges %s from %s', self.user.name, from_revision)
@ -391,13 +391,13 @@ class Nodes(Thread):
def _add(self, user_id): def _add(self, user_id):
if user_id not in self._nodes: if user_id not in self._nodes:
from user.models import User from user.models import User
with db.session():
self._nodes[user_id] = Node(self, User.get_or_create(user_id)) self._nodes[user_id] = Node(self, User.get_or_create(user_id))
else: else:
if not self._nodes[user_id].online: if not self._nodes[user_id].online:
self._nodes[user_id].ping() self._nodes[user_id].ping()
def run(self): def run(self):
with self._app.app_context():
while self._running: while self._running:
args = self._q.get() args = self._q.get()
if args: if args:

View file

@ -17,6 +17,7 @@ from functools import wraps
import logging import logging
logger = logging.getLogger('oxtornado') logger = logging.getLogger('oxtornado')
import db
def json_response(data=None, status=200, text='ok'): def json_response(data=None, status=200, text='ok'):
if not data: if not data:
@ -86,7 +87,7 @@ def api_task(app, request, callback):
logger.debug('API %s %s', action, data) logger.debug('API %s %s', action, data)
f = actions.get(action) f = actions.get(action)
if f: if f:
with app.app_context(): with db.session():
try: try:
response = f(data) response = f(data)
except: except:

View file

@ -75,11 +75,12 @@ def run():
import user import user
import downloads import downloads
import nodes import nodes
import db
state.node = node.server.start(app) state.node = node.server.start(app)
state.nodes = nodes.Nodes(app) state.nodes = nodes.Nodes(app)
state.downloads = downloads.Downloads(app) state.downloads = downloads.Downloads(app)
def add_users(app): def add_users(app):
with app.app_context(): with db.session():
for p in user.models.User.query.filter_by(peered=True): for p in user.models.User.query.filter_by(peered=True):
state.nodes.queue('add', p.id) state.nodes.queue('add', p.id)
state.main.add_callback(add_users, app) state.main.add_callback(add_users, app)

View file

@ -11,3 +11,6 @@ def user():
import settings import settings
import user.models import user.models
return user.models.User.get_or_create(settings.USER_ID) return user.models.User.get_or_create(settings.USER_ID)
from threading import local
db = local()

View file

@ -245,8 +245,8 @@ def sortLists(data):
n += 1 n += 1
if l.type == 'static': if l.type == 'static':
lists.append(l.name) lists.append(l.name)
models.db.session.add(l) state.db.session.add(l)
models.db.session.commit() state.db.session.commit()
if lists: if lists:
Changelog.record(state.user(), 'orderlists', lists) Changelog.record(state.user(), 'orderlists', lists)
return {} return {}
@ -287,8 +287,8 @@ def sortUsers(data):
u = models.User.get(id) u = models.User.get(id)
u.info['index'] = n u.info['index'] = n
n += 1 n += 1
models.db.session.add(u) state.db.session.add(u)
models.db.session.commit() state.db.session.commit()
return {} return {}
actions.register(sortUsers, cache=False) actions.register(sortUsers, cache=False)

View file

@ -50,8 +50,8 @@ class User(db.Model):
return user return user
def save(self): def save(self):
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
@property @property
def name(self): def name(self):
@ -128,8 +128,7 @@ class User(db.Model):
n += 1 n += 1
self.nickname = nickname self.nickname = nickname
metadata = sa.MetaData() list_items = sa.Table('listitem', db.metadata,
list_items = sa.Table('listitem', metadata,
sa.Column('list_id', sa.Integer(), sa.ForeignKey('list.id')), sa.Column('list_id', sa.Integer(), sa.ForeignKey('list.id')),
sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id')) sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id'))
) )
@ -186,8 +185,8 @@ class List(db.Model):
l._query = query l._query = query
l.type = 'smart' if l._query else 'static' l.type = 'smart' if l._query else 'static'
l.index_ = cls.query.filter_by(user_id=user_id).count() l.index_ = cls.query.filter_by(user_id=user_id).count()
db.session.add(l) state.db.session.add(l)
db.session.commit() state.db.session.commit()
if user_id == settings.USER_ID: if user_id == settings.USER_ID:
if not l._query: if not l._query:
Changelog.record(state.user(), 'addlist', l.name) Changelog.record(state.user(), 'addlist', l.name)
@ -220,8 +219,8 @@ class List(db.Model):
if self.user_id == settings.USER_ID: if self.user_id == settings.USER_ID:
i.queue_download() i.queue_download()
i.update() i.update()
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
if self.user_id == settings.USER_ID: if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'addlistitems', self.name, items) Changelog.record(self.user, 'addlistitems', self.name, items)
@ -232,8 +231,8 @@ class List(db.Model):
if i in self.items: if i in self.items:
self.items.remove(i) self.items.remove(i)
i.update() i.update()
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()
if self.user_id == settings.USER_ID: if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'removelistitems', self.name, items) Changelog.record(self.user, 'removelistitems', self.name, items)
@ -244,8 +243,8 @@ class List(db.Model):
if not self._query: if not self._query:
if self.user_id == settings.USER_ID: if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'removelist', self.name) Changelog.record(self.user, 'removelist', self.name)
db.session.delete(self) state.db.session.delete(self)
db.session.commit() state.db.session.commit()
@property @property
def public_id(self): def public_id(self):
@ -293,5 +292,5 @@ class List(db.Model):
return r return r
def save(self): def save(self):
db.session.add(self) state.db.session.add(self)
db.session.commit() state.db.session.commit()