From 8b46a85d5618d9c1198ed148b4fb889eb7ba4529 Mon Sep 17 00:00:00 2001 From: j Date: Sat, 9 Aug 2014 18:14:14 +0200 Subject: [PATCH] one sqlalchemy session per thread --- oml/app.py | 1 - oml/changelog.py | 12 +++++------ oml/db.py | 47 ++++++++++++++++++++++++++++++++++------ oml/downloads.py | 3 ++- oml/item/icons.py | 3 ++- oml/item/models.py | 53 ++++++++++++++++++++++----------------------- oml/item/person.py | 4 ++-- oml/item/scan.py | 14 ++++++------ oml/node/server.py | 7 +++--- oml/nodes.py | 54 +++++++++++++++++++++++----------------------- oml/oxtornado.py | 3 ++- oml/server.py | 3 ++- oml/state.py | 3 +++ oml/user/api.py | 8 +++---- oml/user/models.py | 27 +++++++++++------------ 15 files changed, 140 insertions(+), 102 deletions(-) diff --git a/oml/app.py b/oml/app.py index fc1a3a9..0ceea7b 100644 --- a/oml/app.py +++ b/oml/app.py @@ -13,7 +13,6 @@ import settings import changelog -from db import session import item.models import user.models import item.person diff --git a/oml/changelog.py b/oml/changelog.py index a85f6fc..bd7e827 100644 --- a/oml/changelog.py +++ b/oml/changelog.py @@ -57,8 +57,8 @@ class Changelog(db.Model): c.data = json.dumps([action] + list(args)) _data = str(c.revision) + str(c.timestamp) + c.data c.sig = settings.sk.sign(_data, encoding='base64') - db.session.add(c) - db.session.commit() + state.db.session.add(c) + state.db.session.commit() if state.nodes: state.nodes.queue('peered', 'pushChanges', [c.json()]) @@ -94,8 +94,8 @@ class Changelog(db.Model): logger.debug('apply change from %s: %s', user.name, args) if getattr(c, 'action_' + args[0])(user, timestamp, *args[1:]): logger.debug('change applied') - db.session.add(c) - db.session.commit() + state.db.session.add(c) + state.db.session.commit() if trigger: trigger_event('change', {}); return True @@ -118,8 +118,8 @@ class Changelog(db.Model): for c in cls.query.filter_by(user_id=settings.USER_ID): _data = str(c.revision) + str(c.timestamp) + c.data c.sig = settings.sk.sign(_data, encoding='base64') - db.session.add(c) - db.session.commit() + state.db.session.add(c) + state.db.session.commit() def json(self): timestamp = self.timestamp or datetime2ts(self.created) diff --git a/oml/db.py b/oml/db.py index 45943cd..f461409 100644 --- a/oml/db.py +++ b/oml/db.py @@ -1,19 +1,52 @@ -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from contextlib import contextmanager +from sqlalchemy import create_engine, MetaData +from sqlalchemy import orm +from sqlalchemy.orm.exc import UnmappedClassError +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.ext.mutable import Mutable from sqlalchemy.ext.declarative import declarative_base + + import settings - - +import state engine = create_engine('sqlite:////%s' % settings.db_path) -Session = sessionmaker(bind=engine) +Session = scoped_session(sessionmaker(bind=engine)) -# create a Session -session = Session() +metadata = MetaData() +class _QueryProperty(object): + + def __init__(self): + pass + + def __get__(self, obj, type): + try: + mapper = orm.class_mapper(type) + if mapper: + return type.query_class(mapper, session=state.db.session) + except UnmappedClassError: + return None + +class BaseQuery(orm.Query): + pass + Model = declarative_base() +Model.query_class = BaseQuery +Model.query = _QueryProperty() +Model.metadata = metadata +@contextmanager +def session(): + if hasattr(state.db, 'session'): + state.db.count += 1 + else: + state.db.session = Session() + state.db.count = 1 + yield + state.db.count -= 1 + if not state.db.count: + Session.remove() class MutableDict(Mutable, dict): @classmethod diff --git a/oml/downloads.py b/oml/downloads.py index e7d23a8..bc46250 100644 --- a/oml/downloads.py +++ b/oml/downloads.py @@ -6,6 +6,7 @@ from threading import Thread import time import logging +import db import state import settings import update @@ -43,7 +44,7 @@ class Downloads(Thread): def run(self): time.sleep(2) - with self._app.app_context(): + with db.session(): while self._running: self.download_next() time.sleep(0.5) diff --git a/oml/item/icons.py b/oml/item/icons.py index 588d561..e5a70a1 100644 --- a/oml/item/icons.py +++ b/oml/item/icons.py @@ -17,6 +17,7 @@ from utils import resize_image from settings import icons_db_path +import db import logging logger = logging.getLogger('oml.item.icons') @@ -120,7 +121,7 @@ def get_icon(app, id, type_, size, callback): @run_async def get_icon_app(app, id, type_, size, callback): - with app.app_context(): + with db.session(): from item.models import Item item = Item.get(id) if not item: diff --git a/oml/item/models.py b/oml/item/models.py index 3318d13..a72b1e3 100644 --- a/oml/item/models.py +++ b/oml/item/models.py @@ -41,8 +41,7 @@ from utils import remove_empty_folders logger = logging.getLogger('oml.item.model') -metadata = sa.MetaData() -user_items = sa.Table('useritem', metadata, +user_items = sa.Table('useritem', db.metadata, sa.Column('user_id', sa.String(43), sa.ForeignKey('user.id')), sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id')) ) @@ -97,8 +96,8 @@ class Item(db.Model): item = cls(id=id) if info: item.info = info - db.session.add(item) - db.session.commit() + state.db.session.add(item) + state.db.session.commit() return item def json(self, keys=None): @@ -166,7 +165,7 @@ class Item(db.Model): elif isinstance(value, list): #empty list value = '' setattr(s, key['id'], value) - db.session.add(s) + state.db.session.add(s) def update_find(self): @@ -176,7 +175,7 @@ class Item(db.Model): v = v.decode('utf-8') f.findvalue = unicodedata.normalize('NFKD', v).lower() f.value = v - db.session.add(f) + state.db.session.add(f) for key in config['itemKeys']: if key.get('find') or key.get('filter') or key.get('type') in [['string'], 'string']: @@ -195,7 +194,7 @@ class Item(db.Model): else: f = Find.get(self.id, key['id']) if f: - db.session.delete(f) + state.db.session.delete(f) def update(self): for key in ('mediastate', 'coverRatio', 'previewRatio'): @@ -218,15 +217,15 @@ class Item(db.Model): self.save() def save(self): - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() def delete(self, commit=True): - db.session.delete(self) + state.db.session.delete(self) Sort.query.filter_by(item_id=self.id).delete() Transfer.query.filter_by(item_id=self.id).delete() if commit: - db.session.commit() + state.db.session.commit() meta_keys = ('title', 'author', 'date', 'publisher', 'edition', 'language') @@ -392,13 +391,13 @@ class Item(db.Model): if os.path.exists(path): os.unlink(path) remove_empty_folders(os.path.dirname(path)) - db.session.delete(f) + state.db.session.delete(f) user = state.user() if user in self.users: self.users.remove(user) for l in self.lists.filter_by(user_id=user.id): l.items.remove(self) - db.session.commit() + state.db.session.commit() if not self.users: self.delete() else: @@ -424,8 +423,8 @@ class Sort(db.Model): f = cls.get(item_id) if not f: f = cls(item_id=item_id) - db.session.add(f) - db.session.commit() + state.db.session.add(f) + state.db.session.commit() return f for key in config['itemKeys']: @@ -467,8 +466,8 @@ class Find(db.Model): f = cls.get(item, key) if not f: f = cls(item_id=item, key=key) - db.session.add(f) - db.session.commit() + state.db.session.add(f) + state.db.session.commit() return f class File(db.Model): @@ -499,8 +498,8 @@ class File(db.Model): if path: f.path = path f.item_id = Item.get_or_create(id=sha1, info=info).id - db.session.add(f) - db.session.commit() + state.db.session.add(f) + state.db.session.commit() return f def __repr__(self): @@ -556,8 +555,8 @@ class File(db.Model): self.save() def save(self): - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() class Transfer(db.Model): @@ -587,8 +586,8 @@ class Transfer(db.Model): return t def save(self): - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() class Metadata(db.Model): __tablename__ = 'metadata' @@ -626,14 +625,14 @@ class Metadata(db.Model): def save(self): self.modified = datetime.utcnow() - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() def reset(self): user = state.user() Changelog.record(user, 'resetmeta', self.key, self.value) - db.session.delete(self) - db.session.commit() + state.db.session.delete(self) + state.db.session.commit() self.update_items() def edit(self, data): diff --git a/oml/item/person.py b/oml/item/person.py index c472487..c69099b 100644 --- a/oml/item/person.py +++ b/oml/item/person.py @@ -41,6 +41,6 @@ class Person(db.Model): self.sortname = unicodedata.normalize('NFKD', self.sortname) self.sortsortname = ox.sort_string(self.sortname) self.numberofnames = len(self.name.split(' ')) - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() diff --git a/oml/item/scan.py b/oml/item/scan.py index adc45ae..1708cb3 100644 --- a/oml/item/scan.py +++ b/oml/item/scan.py @@ -11,7 +11,7 @@ import ox from app import app import settings -from settings import db +import db from item.models import File from user.models import User, List @@ -29,7 +29,7 @@ extensions = ['epub', 'pdf', 'txt'] def remove_missing(): dirty = False - with app.app_context(): + with db.session(): prefs = settings.preferences prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/') if os.path.exists(prefix): @@ -39,7 +39,7 @@ def remove_missing(): dirty = True f.item.remove_file() if dirty: - db.session.commit() + state.db.session.commit() def add_file(id, f, prefix, from_=None): user = state.user() @@ -49,10 +49,10 @@ def add_file(id, f, prefix, from_=None): item = file.item if 'primaryid' in file.info: del file.info['primaryid'] - db.session.add(file) + state.db.session.add(file) if 'primaryid' in item.info: item.meta['primaryid'] = item.info.pop('primaryid') - db.session.add(item) + state.db.session.add(item) item.users.append(user) Changelog.record(user, 'additem', item.id, file.info) if item.meta.get('primaryid'): @@ -65,7 +65,7 @@ def add_file(id, f, prefix, from_=None): def run_scan(): remove_missing() - with app.app_context(): + with db.session(): prefs = settings.preferences prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/') if not prefix[-1] == '/': @@ -96,7 +96,7 @@ def run_scan(): def run_import(options=None): options = options or {} - with app.app_context(): + with db.session(): prefs = settings.preferences prefix = os.path.expanduser(options.get('path', prefs['importPath'])) if os.path.islink(prefix): diff --git a/oml/node/server.py b/oml/node/server.py index bcdabc8..40ce7e8 100644 --- a/oml/node/server.py +++ b/oml/node/server.py @@ -10,6 +10,7 @@ from tornado.ioloop import PeriodicCallback import settings import directory +import db import state import user @@ -81,7 +82,7 @@ class NodeHandler(tornado.web.RequestHandler): @run_async def api_call(app, action, key, args, callback): - with app.app_context(): + with db.session(): u = user.models.User.get(key) if action in ( 'requestPeering', 'acceptPeering', 'rejectPeering', 'removePeering' @@ -125,7 +126,7 @@ class ShareHandler(tornado.web.RequestHandler): def publish_node(app): update_online() if state.online: - with app.app_context(): + with db.session(): for u in user.models.User.query.filter_by(queued=True): logger.debug('adding queued node... %s', u.id) state.nodes.queue('add', u.id) @@ -160,7 +161,7 @@ def update_online(): def check_nodes(app): if state.online: - with app.app_context(): + with db.session(): for u in user.models.User.query.filter_by(queued=True): if not state.nodes.is_online(u.id): logger.debug('queued peering message for %s trying to connect...', u.id) diff --git a/oml/nodes.py b/oml/nodes.py index 1d7b1dd..446e69d 100644 --- a/oml/nodes.py +++ b/oml/nodes.py @@ -26,6 +26,7 @@ from websocket import trigger_event from localnodes import LocalNodes from ssl_request import get_opener import state +import db import logging logger = logging.getLogger('oml.nodes') @@ -55,15 +56,14 @@ class Node(Thread): self.ping() def run(self): - with self._app.app_context(): - while self._running: - action = self._q.get() - if not self._running: - break - if action == 'go_online' or not self.online: - self._go_online() - else: - self.online = self.can_connect() + while self._running: + action = self._q.get() + if not self._running: + break + if action == 'go_online' or not self.online: + self._go_online() + else: + self.online = self.can_connect() def join(self): self._running = False @@ -187,7 +187,8 @@ class Node(Thread): @property def user(self): - return user.models.User.get_or_create(self.user_id) + with db.session(): + return user.models.User.get_or_create(self.user_id) def can_connect(self): try: @@ -248,14 +249,13 @@ class Node(Thread): }) def pullChanges(self): - with self._app.app_context(): - last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first() - from_revision = last.revision + 1 if last else 0 - logger.debug('pullChanges %s from %s', self.user.name, from_revision) - changes = self.request('pullChanges', from_revision) - if not changes: - return False - return Changelog.apply_changes(self.user, changes) + last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first() + from_revision = last.revision + 1 if last else 0 + logger.debug('pullChanges %s from %s', self.user.name, from_revision) + changes = self.request('pullChanges', from_revision) + if not changes: + return False + return Changelog.apply_changes(self.user, changes) def pushChanges(self, changes): logger.debug('pushing changes to %s %s', self.user_id, changes) @@ -391,20 +391,20 @@ class Nodes(Thread): def _add(self, user_id): if user_id not in self._nodes: from user.models import User - self._nodes[user_id] = Node(self, User.get_or_create(user_id)) + with db.session(): + self._nodes[user_id] = Node(self, User.get_or_create(user_id)) else: if not self._nodes[user_id].online: self._nodes[user_id].ping() def run(self): - with self._app.app_context(): - while self._running: - args = self._q.get() - if args: - if args[0] == 'add': - self._add(args[1]) - else: - self._call(*args) + while self._running: + args = self._q.get() + if args: + if args[0] == 'add': + self._add(args[1]) + else: + self._call(*args) def join(self): self._running = False diff --git a/oml/oxtornado.py b/oml/oxtornado.py index 9215e54..d0c1f0e 100644 --- a/oml/oxtornado.py +++ b/oml/oxtornado.py @@ -17,6 +17,7 @@ from functools import wraps import logging logger = logging.getLogger('oxtornado') +import db def json_response(data=None, status=200, text='ok'): if not data: @@ -86,7 +87,7 @@ def api_task(app, request, callback): logger.debug('API %s %s', action, data) f = actions.get(action) if f: - with app.app_context(): + with db.session(): try: response = f(data) except: diff --git a/oml/server.py b/oml/server.py index 83433ab..7a62438 100644 --- a/oml/server.py +++ b/oml/server.py @@ -75,11 +75,12 @@ def run(): import user import downloads import nodes + import db state.node = node.server.start(app) state.nodes = nodes.Nodes(app) state.downloads = downloads.Downloads(app) def add_users(app): - with app.app_context(): + with db.session(): for p in user.models.User.query.filter_by(peered=True): state.nodes.queue('add', p.id) state.main.add_callback(add_users, app) diff --git a/oml/state.py b/oml/state.py index fcb54bd..d1a2b06 100644 --- a/oml/state.py +++ b/oml/state.py @@ -11,3 +11,6 @@ def user(): import settings import user.models return user.models.User.get_or_create(settings.USER_ID) + +from threading import local +db = local() diff --git a/oml/user/api.py b/oml/user/api.py index 943b789..4798438 100644 --- a/oml/user/api.py +++ b/oml/user/api.py @@ -245,8 +245,8 @@ def sortLists(data): n += 1 if l.type == 'static': lists.append(l.name) - models.db.session.add(l) - models.db.session.commit() + state.db.session.add(l) + state.db.session.commit() if lists: Changelog.record(state.user(), 'orderlists', lists) return {} @@ -287,8 +287,8 @@ def sortUsers(data): u = models.User.get(id) u.info['index'] = n n += 1 - models.db.session.add(u) - models.db.session.commit() + state.db.session.add(u) + state.db.session.commit() return {} actions.register(sortUsers, cache=False) diff --git a/oml/user/models.py b/oml/user/models.py index 51046de..72f1687 100644 --- a/oml/user/models.py +++ b/oml/user/models.py @@ -50,8 +50,8 @@ class User(db.Model): return user def save(self): - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() @property def name(self): @@ -128,8 +128,7 @@ class User(db.Model): n += 1 self.nickname = nickname -metadata = sa.MetaData() -list_items = sa.Table('listitem', metadata, +list_items = sa.Table('listitem', db.metadata, sa.Column('list_id', sa.Integer(), sa.ForeignKey('list.id')), sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id')) ) @@ -186,8 +185,8 @@ class List(db.Model): l._query = query l.type = 'smart' if l._query else 'static' l.index_ = cls.query.filter_by(user_id=user_id).count() - db.session.add(l) - db.session.commit() + state.db.session.add(l) + state.db.session.commit() if user_id == settings.USER_ID: if not l._query: Changelog.record(state.user(), 'addlist', l.name) @@ -220,8 +219,8 @@ class List(db.Model): if self.user_id == settings.USER_ID: i.queue_download() i.update() - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() if self.user_id == settings.USER_ID: Changelog.record(self.user, 'addlistitems', self.name, items) @@ -232,8 +231,8 @@ class List(db.Model): if i in self.items: self.items.remove(i) i.update() - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit() if self.user_id == settings.USER_ID: Changelog.record(self.user, 'removelistitems', self.name, items) @@ -244,8 +243,8 @@ class List(db.Model): if not self._query: if self.user_id == settings.USER_ID: Changelog.record(self.user, 'removelist', self.name) - db.session.delete(self) - db.session.commit() + state.db.session.delete(self) + state.db.session.commit() @property def public_id(self): @@ -293,5 +292,5 @@ class List(db.Model): return r def save(self): - db.session.add(self) - db.session.commit() + state.db.session.add(self) + state.db.session.commit()