one sqlalchemy session per thread

This commit is contained in:
j 2014-08-09 18:14:14 +02:00
parent 0c08d37c56
commit 8b46a85d56
15 changed files with 140 additions and 102 deletions

View file

@ -13,7 +13,6 @@ import settings
import changelog
from db import session
import item.models
import user.models
import item.person

View file

@ -57,8 +57,8 @@ class Changelog(db.Model):
c.data = json.dumps([action] + list(args))
_data = str(c.revision) + str(c.timestamp) + c.data
c.sig = settings.sk.sign(_data, encoding='base64')
db.session.add(c)
db.session.commit()
state.db.session.add(c)
state.db.session.commit()
if state.nodes:
state.nodes.queue('peered', 'pushChanges', [c.json()])
@ -94,8 +94,8 @@ class Changelog(db.Model):
logger.debug('apply change from %s: %s', user.name, args)
if getattr(c, 'action_' + args[0])(user, timestamp, *args[1:]):
logger.debug('change applied')
db.session.add(c)
db.session.commit()
state.db.session.add(c)
state.db.session.commit()
if trigger:
trigger_event('change', {});
return True
@ -118,8 +118,8 @@ class Changelog(db.Model):
for c in cls.query.filter_by(user_id=settings.USER_ID):
_data = str(c.revision) + str(c.timestamp) + c.data
c.sig = settings.sk.sign(_data, encoding='base64')
db.session.add(c)
db.session.commit()
state.db.session.add(c)
state.db.session.commit()
def json(self):
timestamp = self.timestamp or datetime2ts(self.created)

View file

@ -1,19 +1,52 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from sqlalchemy import create_engine, MetaData
from sqlalchemy import orm
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.ext.declarative import declarative_base
import settings
import state
engine = create_engine('sqlite:////%s' % settings.db_path)
Session = sessionmaker(bind=engine)
Session = scoped_session(sessionmaker(bind=engine))
# create a Session
session = Session()
metadata = MetaData()
class _QueryProperty(object):
def __init__(self):
pass
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=state.db.session)
except UnmappedClassError:
return None
class BaseQuery(orm.Query):
pass
Model = declarative_base()
Model.query_class = BaseQuery
Model.query = _QueryProperty()
Model.metadata = metadata
@contextmanager
def session():
if hasattr(state.db, 'session'):
state.db.count += 1
else:
state.db.session = Session()
state.db.count = 1
yield
state.db.count -= 1
if not state.db.count:
Session.remove()
class MutableDict(Mutable, dict):
@classmethod

View file

@ -6,6 +6,7 @@ from threading import Thread
import time
import logging
import db
import state
import settings
import update
@ -43,7 +44,7 @@ class Downloads(Thread):
def run(self):
time.sleep(2)
with self._app.app_context():
with db.session():
while self._running:
self.download_next()
time.sleep(0.5)

View file

@ -17,6 +17,7 @@ from utils import resize_image
from settings import icons_db_path
import db
import logging
logger = logging.getLogger('oml.item.icons')
@ -120,7 +121,7 @@ def get_icon(app, id, type_, size, callback):
@run_async
def get_icon_app(app, id, type_, size, callback):
with app.app_context():
with db.session():
from item.models import Item
item = Item.get(id)
if not item:

View file

@ -41,8 +41,7 @@ from utils import remove_empty_folders
logger = logging.getLogger('oml.item.model')
metadata = sa.MetaData()
user_items = sa.Table('useritem', metadata,
user_items = sa.Table('useritem', db.metadata,
sa.Column('user_id', sa.String(43), sa.ForeignKey('user.id')),
sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id'))
)
@ -97,8 +96,8 @@ class Item(db.Model):
item = cls(id=id)
if info:
item.info = info
db.session.add(item)
db.session.commit()
state.db.session.add(item)
state.db.session.commit()
return item
def json(self, keys=None):
@ -166,7 +165,7 @@ class Item(db.Model):
elif isinstance(value, list): #empty list
value = ''
setattr(s, key['id'], value)
db.session.add(s)
state.db.session.add(s)
def update_find(self):
@ -176,7 +175,7 @@ class Item(db.Model):
v = v.decode('utf-8')
f.findvalue = unicodedata.normalize('NFKD', v).lower()
f.value = v
db.session.add(f)
state.db.session.add(f)
for key in config['itemKeys']:
if key.get('find') or key.get('filter') or key.get('type') in [['string'], 'string']:
@ -195,7 +194,7 @@ class Item(db.Model):
else:
f = Find.get(self.id, key['id'])
if f:
db.session.delete(f)
state.db.session.delete(f)
def update(self):
for key in ('mediastate', 'coverRatio', 'previewRatio'):
@ -218,15 +217,15 @@ class Item(db.Model):
self.save()
def save(self):
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
def delete(self, commit=True):
db.session.delete(self)
state.db.session.delete(self)
Sort.query.filter_by(item_id=self.id).delete()
Transfer.query.filter_by(item_id=self.id).delete()
if commit:
db.session.commit()
state.db.session.commit()
meta_keys = ('title', 'author', 'date', 'publisher', 'edition', 'language')
@ -392,13 +391,13 @@ class Item(db.Model):
if os.path.exists(path):
os.unlink(path)
remove_empty_folders(os.path.dirname(path))
db.session.delete(f)
state.db.session.delete(f)
user = state.user()
if user in self.users:
self.users.remove(user)
for l in self.lists.filter_by(user_id=user.id):
l.items.remove(self)
db.session.commit()
state.db.session.commit()
if not self.users:
self.delete()
else:
@ -424,8 +423,8 @@ class Sort(db.Model):
f = cls.get(item_id)
if not f:
f = cls(item_id=item_id)
db.session.add(f)
db.session.commit()
state.db.session.add(f)
state.db.session.commit()
return f
for key in config['itemKeys']:
@ -467,8 +466,8 @@ class Find(db.Model):
f = cls.get(item, key)
if not f:
f = cls(item_id=item, key=key)
db.session.add(f)
db.session.commit()
state.db.session.add(f)
state.db.session.commit()
return f
class File(db.Model):
@ -499,8 +498,8 @@ class File(db.Model):
if path:
f.path = path
f.item_id = Item.get_or_create(id=sha1, info=info).id
db.session.add(f)
db.session.commit()
state.db.session.add(f)
state.db.session.commit()
return f
def __repr__(self):
@ -556,8 +555,8 @@ class File(db.Model):
self.save()
def save(self):
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
class Transfer(db.Model):
@ -587,8 +586,8 @@ class Transfer(db.Model):
return t
def save(self):
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
class Metadata(db.Model):
__tablename__ = 'metadata'
@ -626,14 +625,14 @@ class Metadata(db.Model):
def save(self):
self.modified = datetime.utcnow()
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
def reset(self):
user = state.user()
Changelog.record(user, 'resetmeta', self.key, self.value)
db.session.delete(self)
db.session.commit()
state.db.session.delete(self)
state.db.session.commit()
self.update_items()
def edit(self, data):

View file

@ -41,6 +41,6 @@ class Person(db.Model):
self.sortname = unicodedata.normalize('NFKD', self.sortname)
self.sortsortname = ox.sort_string(self.sortname)
self.numberofnames = len(self.name.split(' '))
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()

View file

@ -11,7 +11,7 @@ import ox
from app import app
import settings
from settings import db
import db
from item.models import File
from user.models import User, List
@ -29,7 +29,7 @@ extensions = ['epub', 'pdf', 'txt']
def remove_missing():
dirty = False
with app.app_context():
with db.session():
prefs = settings.preferences
prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/')
if os.path.exists(prefix):
@ -39,7 +39,7 @@ def remove_missing():
dirty = True
f.item.remove_file()
if dirty:
db.session.commit()
state.db.session.commit()
def add_file(id, f, prefix, from_=None):
user = state.user()
@ -49,10 +49,10 @@ def add_file(id, f, prefix, from_=None):
item = file.item
if 'primaryid' in file.info:
del file.info['primaryid']
db.session.add(file)
state.db.session.add(file)
if 'primaryid' in item.info:
item.meta['primaryid'] = item.info.pop('primaryid')
db.session.add(item)
state.db.session.add(item)
item.users.append(user)
Changelog.record(user, 'additem', item.id, file.info)
if item.meta.get('primaryid'):
@ -65,7 +65,7 @@ def add_file(id, f, prefix, from_=None):
def run_scan():
remove_missing()
with app.app_context():
with db.session():
prefs = settings.preferences
prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books/')
if not prefix[-1] == '/':
@ -96,7 +96,7 @@ def run_scan():
def run_import(options=None):
options = options or {}
with app.app_context():
with db.session():
prefs = settings.preferences
prefix = os.path.expanduser(options.get('path', prefs['importPath']))
if os.path.islink(prefix):

View file

@ -10,6 +10,7 @@ from tornado.ioloop import PeriodicCallback
import settings
import directory
import db
import state
import user
@ -81,7 +82,7 @@ class NodeHandler(tornado.web.RequestHandler):
@run_async
def api_call(app, action, key, args, callback):
with app.app_context():
with db.session():
u = user.models.User.get(key)
if action in (
'requestPeering', 'acceptPeering', 'rejectPeering', 'removePeering'
@ -125,7 +126,7 @@ class ShareHandler(tornado.web.RequestHandler):
def publish_node(app):
update_online()
if state.online:
with app.app_context():
with db.session():
for u in user.models.User.query.filter_by(queued=True):
logger.debug('adding queued node... %s', u.id)
state.nodes.queue('add', u.id)
@ -160,7 +161,7 @@ def update_online():
def check_nodes(app):
if state.online:
with app.app_context():
with db.session():
for u in user.models.User.query.filter_by(queued=True):
if not state.nodes.is_online(u.id):
logger.debug('queued peering message for %s trying to connect...', u.id)

View file

@ -26,6 +26,7 @@ from websocket import trigger_event
from localnodes import LocalNodes
from ssl_request import get_opener
import state
import db
import logging
logger = logging.getLogger('oml.nodes')
@ -55,15 +56,14 @@ class Node(Thread):
self.ping()
def run(self):
with self._app.app_context():
while self._running:
action = self._q.get()
if not self._running:
break
if action == 'go_online' or not self.online:
self._go_online()
else:
self.online = self.can_connect()
while self._running:
action = self._q.get()
if not self._running:
break
if action == 'go_online' or not self.online:
self._go_online()
else:
self.online = self.can_connect()
def join(self):
self._running = False
@ -187,7 +187,8 @@ class Node(Thread):
@property
def user(self):
return user.models.User.get_or_create(self.user_id)
with db.session():
return user.models.User.get_or_create(self.user_id)
def can_connect(self):
try:
@ -248,14 +249,13 @@ class Node(Thread):
})
def pullChanges(self):
with self._app.app_context():
last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first()
from_revision = last.revision + 1 if last else 0
logger.debug('pullChanges %s from %s', self.user.name, from_revision)
changes = self.request('pullChanges', from_revision)
if not changes:
return False
return Changelog.apply_changes(self.user, changes)
last = Changelog.query.filter_by(user_id=self.user_id).order_by('-revision').first()
from_revision = last.revision + 1 if last else 0
logger.debug('pullChanges %s from %s', self.user.name, from_revision)
changes = self.request('pullChanges', from_revision)
if not changes:
return False
return Changelog.apply_changes(self.user, changes)
def pushChanges(self, changes):
logger.debug('pushing changes to %s %s', self.user_id, changes)
@ -391,20 +391,20 @@ class Nodes(Thread):
def _add(self, user_id):
if user_id not in self._nodes:
from user.models import User
self._nodes[user_id] = Node(self, User.get_or_create(user_id))
with db.session():
self._nodes[user_id] = Node(self, User.get_or_create(user_id))
else:
if not self._nodes[user_id].online:
self._nodes[user_id].ping()
def run(self):
with self._app.app_context():
while self._running:
args = self._q.get()
if args:
if args[0] == 'add':
self._add(args[1])
else:
self._call(*args)
while self._running:
args = self._q.get()
if args:
if args[0] == 'add':
self._add(args[1])
else:
self._call(*args)
def join(self):
self._running = False

View file

@ -17,6 +17,7 @@ from functools import wraps
import logging
logger = logging.getLogger('oxtornado')
import db
def json_response(data=None, status=200, text='ok'):
if not data:
@ -86,7 +87,7 @@ def api_task(app, request, callback):
logger.debug('API %s %s', action, data)
f = actions.get(action)
if f:
with app.app_context():
with db.session():
try:
response = f(data)
except:

View file

@ -75,11 +75,12 @@ def run():
import user
import downloads
import nodes
import db
state.node = node.server.start(app)
state.nodes = nodes.Nodes(app)
state.downloads = downloads.Downloads(app)
def add_users(app):
with app.app_context():
with db.session():
for p in user.models.User.query.filter_by(peered=True):
state.nodes.queue('add', p.id)
state.main.add_callback(add_users, app)

View file

@ -11,3 +11,6 @@ def user():
import settings
import user.models
return user.models.User.get_or_create(settings.USER_ID)
from threading import local
db = local()

View file

@ -245,8 +245,8 @@ def sortLists(data):
n += 1
if l.type == 'static':
lists.append(l.name)
models.db.session.add(l)
models.db.session.commit()
state.db.session.add(l)
state.db.session.commit()
if lists:
Changelog.record(state.user(), 'orderlists', lists)
return {}
@ -287,8 +287,8 @@ def sortUsers(data):
u = models.User.get(id)
u.info['index'] = n
n += 1
models.db.session.add(u)
models.db.session.commit()
state.db.session.add(u)
state.db.session.commit()
return {}
actions.register(sortUsers, cache=False)

View file

@ -50,8 +50,8 @@ class User(db.Model):
return user
def save(self):
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
@property
def name(self):
@ -128,8 +128,7 @@ class User(db.Model):
n += 1
self.nickname = nickname
metadata = sa.MetaData()
list_items = sa.Table('listitem', metadata,
list_items = sa.Table('listitem', db.metadata,
sa.Column('list_id', sa.Integer(), sa.ForeignKey('list.id')),
sa.Column('item_id', sa.String(32), sa.ForeignKey('item.id'))
)
@ -186,8 +185,8 @@ class List(db.Model):
l._query = query
l.type = 'smart' if l._query else 'static'
l.index_ = cls.query.filter_by(user_id=user_id).count()
db.session.add(l)
db.session.commit()
state.db.session.add(l)
state.db.session.commit()
if user_id == settings.USER_ID:
if not l._query:
Changelog.record(state.user(), 'addlist', l.name)
@ -220,8 +219,8 @@ class List(db.Model):
if self.user_id == settings.USER_ID:
i.queue_download()
i.update()
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'addlistitems', self.name, items)
@ -232,8 +231,8 @@ class List(db.Model):
if i in self.items:
self.items.remove(i)
i.update()
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()
if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'removelistitems', self.name, items)
@ -244,8 +243,8 @@ class List(db.Model):
if not self._query:
if self.user_id == settings.USER_ID:
Changelog.record(self.user, 'removelist', self.name)
db.session.delete(self)
db.session.commit()
state.db.session.delete(self)
state.db.session.commit()
@property
def public_id(self):
@ -293,5 +292,5 @@ class List(db.Model):
return r
def save(self):
db.session.add(self)
db.session.commit()
state.db.session.add(self)
state.db.session.commit()