use one variable to track app shutdown state

This commit is contained in:
j 2016-01-31 22:15:14 +05:30
parent 0cc3a4523e
commit 567952d91d
7 changed files with 35 additions and 44 deletions

View file

@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
class Downloads(Thread): class Downloads(Thread):
def __init__(self): def __init__(self):
self._running = True
Thread.__init__(self) Thread.__init__(self)
self.daemon = True self.daemon = True
self.start() self.start()
@ -36,9 +35,11 @@ class Downloads(Thread):
for t in item.models.Transfer.query.filter( for t in item.models.Transfer.query.filter(
item.models.Transfer.added!=None, item.models.Transfer.added!=None,
item.models.Transfer.progress<1).order_by(item.models.Transfer.added): item.models.Transfer.progress<1).order_by(item.models.Transfer.added):
if not self._running: if state.shutdown:
return False return False
for u in t.item.users: for u in t.item.users:
if state.shutdown:
return False
if state.nodes.is_online(u.id): if state.nodes.is_online(u.id):
logger.debug('DOWNLOAD %s %s', t.item, u) logger.debug('DOWNLOAD %s %s', t.item, u)
r = state.nodes.download(u.id, t.item) r = state.nodes.download(u.id, t.item)
@ -47,14 +48,13 @@ class Downloads(Thread):
def run(self): def run(self):
self.wait(10) self.wait(10)
while self._running: while not state.shutdown:
self.wait_online() self.wait_online()
with db.session(): with db.session():
self.download_next() self.download_next()
self.wait(10) self.wait(10)
def join(self): def join(self):
self._running = False
return Thread.join(self) return Thread.join(self)
def wait_online(self): def wait_online(self):
@ -63,7 +63,7 @@ class Downloads(Thread):
def wait(self, timeout): def wait(self, timeout):
step = min(timeout, 1) step = min(timeout, 1)
while self._running and timeout > 0: while not state.shutdown and timeout > 0:
time.sleep(step) time.sleep(step)
timeout -= step timeout -= step

View file

@ -31,7 +31,7 @@ def remove_missing():
prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books' + os.sep) prefix = os.path.join(os.path.expanduser(prefs['libraryPath']), 'Books' + os.sep)
if os.path.exists(prefix): if os.path.exists(prefix):
for f in File.query: for f in File.query:
if not state.tasks.connected: if state.shutdown:
return return
if f.item: if f.item:
path = f.item.get_path() path = f.item.get_path()
@ -45,7 +45,7 @@ def remove_missing():
state.db.session.commit() state.db.session.commit()
state.cache.clear('group:') state.cache.clear('group:')
for f in File.query: for f in File.query:
if not state.tasks.connected: if state.shutdown:
return return
f.move() f.move()
remove_empty_folders(prefix, True) remove_empty_folders(prefix, True)
@ -82,7 +82,7 @@ def run_scan():
books = [] books = []
for root, folders, files in os.walk(prefix): for root, folders, files in os.walk(prefix):
for f in files: for f in files:
if not state.tasks.connected: if state.shutdown:
return return
#if f.startswith('._') or f == '.DS_Store': #if f.startswith('._') or f == '.DS_Store':
if f.startswith('.'): if f.startswith('.'):
@ -97,7 +97,7 @@ def run_scan():
position = 0 position = 0
added = 0 added = 0
for f in ox.sorted_strings(books): for f in ox.sorted_strings(books):
if not state.tasks.connected: if state.shutdown:
return return
position += 1 position += 1
with db.session(): with db.session():
@ -159,7 +159,7 @@ def run_import(options=None):
count = 0 count = 0
for root, folders, files in os.walk(prefix): for root, folders, files in os.walk(prefix):
for f in files: for f in files:
if not state.tasks.connected: if state.shutdown:
return return
#if f.startswith('._') or f == '.DS_Store': #if f.startswith('._') or f == '.DS_Store':
if f.startswith('.'): if f.startswith('.'):
@ -217,7 +217,7 @@ def run_import(options=None):
if state.activity.get('cancel'): if state.activity.get('cancel'):
state.activity = {} state.activity = {}
return return
if not state.tasks.connected: if state.shutdown:
return return
if time.time() - last > 5: if time.time() - last > 5:
last = time.time() last = time.time()

View file

@ -55,7 +55,6 @@ class LocalNodesBase(Thread):
def __init__(self, nodes): def __init__(self, nodes):
self._socket = None self._socket = None
self._active = True
self._nodes = nodes self._nodes = nodes
Thread.__init__(self) Thread.__init__(self)
if not server['localnode_discovery']: if not server['localnode_discovery']:
@ -87,12 +86,12 @@ class LocalNodesBase(Thread):
last = time.mktime(time.localtime()) last = time.mktime(time.localtime())
s = self.get_socket() s = self.get_socket()
s.bind(('', self._PORT)) s.bind(('', self._PORT))
while self._active: while not state.shutdown:
try: try:
r, _, _ = select.select([s], [], [], 3) r, _, _ = select.select([s], [], [], 3)
if r: if r:
data, addr = s.recvfrom(1024) data, addr = s.recvfrom(1024)
if self._active: if not state.shutdown:
while data[-1] == 0: while data[-1] == 0:
data = data[:-1] # Strip trailing \0's data = data[:-1] # Strip trailing \0's
data = self.verify(data) data = self.verify(data)
@ -101,11 +100,11 @@ class LocalNodesBase(Thread):
except OSError: # no local interface exists except OSError: # no local interface exists
self.wait(60) self.wait(60)
except: except:
if self._active: if not state.shutdown:
logger.debug('receive failed. restart later', exc_info=True) logger.debug('receive failed. restart later', exc_info=True)
self.wait(60) self.wait(60)
finally: finally:
if self._active: if not state.shutdown:
now = time.mktime(time.localtime()) now = time.mktime(time.localtime())
if now - last > 60: if now - last > 60:
last = now last = now
@ -159,7 +158,6 @@ class LocalNodesBase(Thread):
self.receive() self.receive()
def join(self): def join(self):
self._active = False
if self._socket: if self._socket:
try: try:
self._socket.shutdown(socket.SHUT_RDWR) self._socket.shutdown(socket.SHUT_RDWR)
@ -170,7 +168,7 @@ class LocalNodesBase(Thread):
def wait(self, timeout): def wait(self, timeout):
step = min(timeout, 1) step = min(timeout, 1)
while self._active and timeout > 0: while not state.shutdown and timeout > 0:
time.sleep(step) time.sleep(step)
timeout -= step timeout -= step
@ -241,7 +239,6 @@ class LocalNodes6(LocalNodesBase):
class LocalNodes(object): class LocalNodes(object):
_active = True
_nodes4 = None _nodes4 = None
_nodes6 = None _nodes6 = None
@ -253,7 +250,7 @@ class LocalNodes(object):
#self._nodes6 = LocalNodes6(self._nodes) #self._nodes6 = LocalNodes6(self._nodes)
def cleanup(self): def cleanup(self):
if self._active: if not state.shutdown:
for id in list(self._nodes.keys()): for id in list(self._nodes.keys()):
if not can_connect(self._nodes[id]): if not can_connect(self._nodes[id]):
with db.session(): with db.session():
@ -262,7 +259,7 @@ class LocalNodes(object):
del u.info['local'] del u.info['local']
u.save() u.save()
del self._nodes[id] del self._nodes[id]
if not self._active: if state.shutdown:
break break
def get(self, user_id): def get(self, user_id):
@ -271,7 +268,6 @@ class LocalNodes(object):
return self._nodes[user_id] return self._nodes[user_id]
def join(self): def join(self):
self._active = False
if self._nodes4: if self._nodes4:
self._nodes4.join() self._nodes4.join()
if self._nodes6: if self._nodes6:

View file

@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
ENCODING='base64' ENCODING='base64'
class Node(Thread): class Node(Thread):
_running = True
host = None host = None
local = None local = None
_online = None _online = None
@ -46,13 +45,13 @@ class Node(Thread):
self._q = Queue() self._q = Queue()
Thread.__init__(self) Thread.__init__(self)
self.daemon = True self.daemon = True
self.start()
self.ping() self.ping()
self.start()
def run(self): def run(self):
while self._running: while not state.shutdown:
action = self._q.get() action = self._q.get()
if not self._running: if state.shutdown:
break break
if action == 'send_response': if action == 'send_response':
self._send_response() self._send_response()
@ -62,11 +61,9 @@ class Node(Thread):
logger.debug('unknown action %s', action) logger.debug('unknown action %s', action)
def join(self): def join(self):
self._running = False
self._q.put('') self._q.put('')
#return Thread.join(self) #return Thread.join(self)
def ping(self): def ping(self):
if state.online: if state.online:
self._q.put('ping') self._q.put('ping')
@ -155,7 +152,6 @@ class Node(Thread):
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
if e.code == 403: if e.code == 403:
logger.debug('403: %s (%s)', url, self.user_id) logger.debug('403: %s (%s)', url, self.user_id)
self._running = False
if state.tasks: if state.tasks:
state.tasks.queue('peering', (self.user_id, False)) state.tasks.queue('peering', (self.user_id, False))
del self._nodes[self.user_id] del self._nodes[self.user_id]
@ -335,7 +331,7 @@ class Node(Thread):
'id': item.id, 'progress': t.progress 'id': item.id, 'progress': t.progress
}) })
if state.bandwidth: if state.bandwidth:
while not state.bandwidth.download(chunk_size) and self._running: while not state.bandwidth.download(chunk_size) and not state.shutdown:
time.sleep(0.1) time.sleep(0.1)
t2 = datetime.utcnow() t2 = datetime.utcnow()
duration = (t2-t1).total_seconds() duration = (t2-t1).total_seconds()
@ -407,7 +403,6 @@ class Nodes(Thread):
def __init__(self): def __init__(self):
self._q = Queue() self._q = Queue()
self._running = True
with db.session(): with db.session():
for u in user.models.User.query.filter_by(peered=True): for u in user.models.User.query.filter_by(peered=True):
if 'local' in u.info: if 'local' in u.info:
@ -427,7 +422,7 @@ class Nodes(Thread):
self.start() self.start()
def cleanup(self): def cleanup(self):
if self._running and self._local: if not state.shutdown and self._local:
self._local.cleanup() self._local.cleanup()
def pull(self): def pull(self):
@ -479,14 +474,14 @@ class Nodes(Thread):
return return
self._pulling = True self._pulling = True
for node in list(self._nodes.values()): for node in list(self._nodes.values()):
if self._running: if not state.shutdown:
node.online = node.can_connect() node.online = node.can_connect()
if self._running and node.online: if not state.shutdown and node.online:
node.pullChanges() node.pullChanges()
self._pulling = False self._pulling = False
def run(self): def run(self):
while self._running: while not state.shutdown:
args = self._q.get() args = self._q.get()
if args: if args:
if args[0] == 'cleanup': if args[0] == 'cleanup':
@ -499,7 +494,6 @@ class Nodes(Thread):
self._call(*args) self._call(*args)
def join(self): def join(self):
self._running = False
self._q.put(None) self._q.put(None)
for node in list(self._nodes.values()): for node in list(self._nodes.values()):
node.join() node.join()

View file

@ -60,8 +60,12 @@ def log_request(handler):
handler._request_summary(), request_time) handler._request_summary(), request_time)
def shutdown(): def shutdown():
state.shutdown = True
if state.tor: if state.tor:
state.tor._shutdown = True state.tor._shutdown = True
if state.nodes:
logger.debug('shutdown nodes')
state.nodes.join()
if state.downloads: if state.downloads:
logger.debug('shutdown downloads') logger.debug('shutdown downloads')
state.downloads.join() state.downloads.join()
@ -70,9 +74,6 @@ def shutdown():
if state.tasks: if state.tasks:
logger.debug('shutdown tasks') logger.debug('shutdown tasks')
state.tasks.join() state.tasks.join()
if state.nodes:
logger.debug('shutdown nodes')
state.nodes.join()
if state.node: if state.node:
state.node.stop() state.node.stop()
if state.tor: if state.tor:

View file

@ -8,6 +8,7 @@ tasks = False
downloads = False downloads = False
tor = False tor = False
update = False update = False
shutdown = False
websockets = [] websockets = []
activity = {} activity = {}

View file

@ -5,6 +5,7 @@ from queue import Queue
from threading import Thread from threading import Thread
from websocket import trigger_event from websocket import trigger_event
import state
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -13,7 +14,6 @@ class Tasks(Thread):
def __init__(self): def __init__(self):
self.q = Queue() self.q = Queue()
self.connected = True
Thread.__init__(self) Thread.__init__(self)
self.daemon = True self.daemon = True
self.start() self.start()
@ -23,9 +23,9 @@ class Tasks(Thread):
import item.scan import item.scan
from item.models import sync_metadata, get_preview from item.models import sync_metadata, get_preview
from user.models import export_list, update_user_peering from user.models import export_list, update_user_peering
while self.connected: while not state.shutdown:
m = self.q.get() m = self.q.get()
if m and self.connected: if m and not state.shutdown:
try: try:
action, data = m action, data = m
logger.debug('%s start', action) logger.debug('%s start', action)
@ -55,12 +55,11 @@ class Tasks(Thread):
self.q.task_done() self.q.task_done()
def join(self): def join(self):
self.connected = False
self.q.put(None) self.q.put(None)
return Thread.join(self) return Thread.join(self)
def queue(self, action, data=None): def queue(self, action, data=None):
if self.connected: if not state.shutdown:
logger.debug('%s queued', action) logger.debug('%s queued', action)
self.q.put((action, data)) self.q.put((action, data))