cleanly shut down

This commit is contained in:
j 2014-08-09 20:32:41 +02:00
parent 9ecf102c01
commit 4b790722ae
6 changed files with 37 additions and 14 deletions

View file

@ -9,6 +9,7 @@ from sqlalchemy.ext.declarative import declarative_base
import settings import settings
import state import state
engine = create_engine('sqlite:////%s' % settings.db_path) engine = create_engine('sqlite:////%s' % settings.db_path)
Session = scoped_session(sessionmaker(bind=engine)) Session = scoped_session(sessionmaker(bind=engine))
@ -43,10 +44,13 @@ def session():
else: else:
state.db.session = Session() state.db.session = Session()
state.db.count = 1 state.db.count = 1
yield try:
state.db.count -= 1 yield state.db.session
if not state.db.count: finally:
Session.remove() state.db.count -= 1
if not state.db.count:
state.db.session.close()
Session.remove()
class MutableDict(Mutable, dict): class MutableDict(Mutable, dict):
@classmethod @classmethod

View file

@ -50,6 +50,4 @@ class Downloads(Thread):
def join(self): def join(self):
self._running = False self._running = False
self._q.put(None)
return Thread.join(self) return Thread.join(self)

View file

@ -39,6 +39,7 @@ class LocalNodesBase(Thread):
_TTL = 1 _TTL = 1
def __init__(self, nodes): def __init__(self, nodes):
self._socket = None
self._active = True self._active = True
self._nodes = nodes self._nodes = nodes
Thread.__init__(self) Thread.__init__(self)
@ -71,11 +72,12 @@ class LocalNodesBase(Thread):
s.bind(('', self._PORT)) s.bind(('', self._PORT))
while self._active: while self._active:
data, addr = s.recvfrom(1024) data, addr = s.recvfrom(1024)
while data[-1] == '\0': if self._active:
data = data[:-1] # Strip trailing \0's while data[-1] == '\0':
data = self.verify(data) data = data[:-1] # Strip trailing \0's
if data: data = self.verify(data)
self.update_node(data) if data:
self.update_node(data)
except: except:
logger.debug('receive failed. restart later', exc_info=1) logger.debug('receive failed. restart later', exc_info=1)
time.sleep(10) time.sleep(10)
@ -133,6 +135,12 @@ class LocalNodesBase(Thread):
def join(self): def join(self):
self._active = False self._active = False
if self._socket:
try:
self._socket.shutdown(socket.SHUT_RDWR)
except:
pass
self._socket.close()
return Thread.join(self) return Thread.join(self)
class LocalNodes4(LocalNodesBase): class LocalNodes4(LocalNodesBase):
@ -156,6 +164,7 @@ class LocalNodes4(LocalNodesBase):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
mreq = struct.pack("=4sl", socket.inet_aton(self._BROADCAST), socket.INADDR_ANY) mreq = struct.pack("=4sl", socket.inet_aton(self._BROADCAST), socket.INADDR_ANY)
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
self._socket = s
return s return s
def get_ip(self): def get_ip(self):
@ -186,6 +195,7 @@ class LocalNodes6(LocalNodesBase):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
group_bin = socket.inet_pton(socket.AF_INET6, self._BROADCAST) + '\0'*4 group_bin = socket.inet_pton(socket.AF_INET6, self._BROADCAST) + '\0'*4
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, group_bin) s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, group_bin)
self._socket = s
return s return s
def get_ip(self): def get_ip(self):

View file

@ -67,7 +67,7 @@ class Node(Thread):
def join(self): def join(self):
self._running = False self._running = False
self.ping() self.ping()
return Thread.join(self) #return Thread.join(self)
def ping(self): def ping(self):
self._q.put('') self._q.put('')
@ -409,4 +409,5 @@ class Nodes(Thread):
self._q.put(None) self._q.put(None)
for node in self._nodes.values(): for node in self._nodes.values():
node.join() node.join()
self._local.join()
return Thread.join(self) return Thread.join(self)

View file

@ -92,4 +92,14 @@ def run():
host = settings.server['address'] host = settings.server['address']
url = 'http://%s:%s/' % (host, settings.server['port']) url = 'http://%s:%s/' % (host, settings.server['port'])
print 'open browser at %s' % url print 'open browser at %s' % url
state.main.start() try:
state.main.start()
except:
print 'shutting down...'
if state.downloads:
state.downloads.join()
if state.tasks:
state.tasks.join()
if state.nodes:
state.nodes.join()

View file

@ -41,7 +41,7 @@ class Tasks(Thread):
def join(self): def join(self):
self.connected = False self.connected = False
self.put(None) self.q.put(None)
self.q.join() self.q.join()
return Thread.join(self) return Thread.join(self)