From 2b58800caab2a5997dccb8921c8f5965fd24032c Mon Sep 17 00:00:00 2001 From: j Date: Sat, 8 Jun 2024 13:31:46 +0100 Subject: [PATCH] fix local peer discovery --- oml/localnodes.py | 82 ++++++++++++++++++++++++++++------------------- oml/nodes.py | 4 +-- oml/server.py | 9 +++--- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/oml/localnodes.py b/oml/localnodes.py index 6254378..4f8b8e1 100644 --- a/oml/localnodes.py +++ b/oml/localnodes.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- +import asyncio import socket import netifaces from zeroconf import ( - ServiceBrowser, ServiceInfo, ServiceStateChange, Zeroconf + ServiceBrowser, ServiceInfo, ServiceStateChange ) +from zeroconf.asyncio import AsyncZeroconf from tornado.ioloop import PeriodicCallback import settings @@ -64,94 +66,108 @@ class LocalNodes(dict): def setup(self): self.local_ips = get_broadcast_interfaces() - self.zeroconf = {ip: Zeroconf(interfaces=[ip]) for ip in self.local_ips} - self.register_service() + self.zeroconf = {ip: AsyncZeroconf(interfaces=[ip]) for ip in self.local_ips} + asyncio.create_task(self.register_service()) self.browse() def _update_if_ip_changed(self): local_ips = get_broadcast_interfaces() username = settings.preferences.get('username', 'anonymous') if local_ips != self.local_ips or self.username != username: - self.close() + asyncio.run(self.close()) self.setup() def browse(self): self.browser = { - ip: ServiceBrowser(self.zeroconf[ip], self.service_type, handlers=[self.on_service_state_change]) + ip: ServiceBrowser(self.zeroconf[ip].zeroconf, self.service_type, handlers=[self.on_service_state_change]) for ip in self.zeroconf } - def register_service(self): + async def register_service(self): if self.local_info: for local_ip, local_info in self.local_info: - self.zeroconf[local_ip].unregister_service(local_info) + self.zeroconf[local_ip].async_unregister_service(local_info) self.local_info = None local_name = socket.gethostname().partition('.')[0] + '.local.' port = settings.server['node_port'] + self.local_info = [] self.username = settings.preferences.get('username', 'anonymous') desc = { - 'username': self.username + 'username': self.username, + 'id': settings.USER_ID, } - self.local_info = [] + tasks = [] for i, local_ip in enumerate(get_broadcast_interfaces()): if i: - name = '%s-%s [%s].%s' % (desc['username'], i+1, settings.USER_ID, self.service_type) + name = '%s [%s].%s' % (desc['username'], i, self.service_type) else: - name = '%s [%s].%s' % (desc['username'], settings.USER_ID, self.service_type) - local_info = ServiceInfo(self.service_type, name, - socket.inet_aton(local_ip), port, 0, 0, desc, local_name) - self.zeroconf[local_ip].register_service(local_info) + name = '%s.%s' % (desc['username'], self.service_type) + + addresses = [socket.inet_aton(local_ip)] + local_info = ServiceInfo(self.service_type, name, port, 0, 0, desc, local_name, addresses=addresses) + task = self.zeroconf[local_ip].async_register_service(local_info) + tasks.append(task) self.local_info.append((local_ip, local_info)) + await asyncio.gather(*tasks) def __del__(self): self.close() - def close(self): + async def close(self): if self.local_info: + tasks = [] for local_ip, local_info in self.local_info: try: - self.zeroconf[local_ip].unregister_service(local_info) + task = self.zeroconf[local_ip].async_unregister_service(local_info) + tasks.append(task) except: logger.debug('exception closing zeroconf', exc_info=True) self.local_info = None if self.zeroconf: for local_ip in self.zeroconf: try: - self.zeroconf[local_ip].close() + task = self.zeroconf[local_ip].async_close() + tasks.append(task) except: logger.debug('exception closing zeroconf', exc_info=True) self.zeroconf = None for id in list(self): self.pop(id, None) + await asyncio.gather(*tasks) def on_service_state_change(self, zeroconf, service_type, name, state_change): - if '[' not in name: - id = name.split('.')[0] - else: - id = name.split('[')[1].split(']')[0] - if id == settings.USER_ID: - return - if state_change is ServiceStateChange.Added: - info = zeroconf.get_service_info(service_type, name) - if info: + info = zeroconf.get_service_info(service_type, name) + if info and b'id' in info.properties: + id = info.properties[b'id'].decode() + if id == settings.USER_ID: + return + if state_change is ServiceStateChange.Added: + new = id not in self self[id] = { 'id': id, - 'host': socket.inet_ntoa(info.address), + 'host': socket.inet_ntoa(info.addresses[0]), 'port': info.port } if info.properties: for key, value in info.properties.items(): key = key.decode() self[id][key] = value.decode() - logger.debug('add: %s [%s] (%s:%s)', self[id].get('username', 'anon'), id, self[id]['host'], self[id]['port']) + logger.debug( + '%s: %s [%s] (%s:%s)', + 'add' if new else 'update', + self[id].get('username', 'anon'), + id, + self[id]['host'], + self[id]['port'] + ) if state.tasks and id in self: state.tasks.queue('addlocalinfo', self[id]) - elif state_change is ServiceStateChange.Removed: - logger.debug('remove: %s', id) - self.pop(id, None) - if state.tasks: - state.tasks.queue('removelocalinfo', id) + elif state_change is ServiceStateChange.Removed: + logger.debug('remove: %s', id) + self.pop(id, None) + if state.tasks: + state.tasks.queue('removelocalinfo', id) def get_data(self, user_id): data = self.get(user_id) diff --git a/oml/nodes.py b/oml/nodes.py index 7487ca3..d6c1112 100644 --- a/oml/nodes.py +++ b/oml/nodes.py @@ -622,12 +622,12 @@ class Nodes(Thread): node.pullChanges() self._pulling = False - def join(self): + async def join(self): self._q.put(None) for node in list(self._nodes.values()): node.join() if self.local: - self.local.close() + await self.local.close() return super().join(1) def publish_node(): diff --git a/oml/server.py b/oml/server.py index 130f120..a557054 100644 --- a/oml/server.py +++ b/oml/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import asyncio import os import sys import signal @@ -59,13 +60,13 @@ def log_request(handler): log_method("%d %s %.2fms", handler.get_status(), handler._request_summary(), request_time) -def shutdown(): +async def shutdown(): state.shutdown = True if state.tor: state.tor._shutdown = True if state.nodes: logger.debug('shutdown nodes') - state.nodes.join() + await state.nodes.join() if state.downloads: logger.debug('shutdown downloads') state.downloads.join() @@ -197,10 +198,10 @@ def run(): print('open browser at %s' % url) logger.debug('Starting OML %s at %s', settings.VERSION, url) - signal.signal(signal.SIGTERM, shutdown) + signal.signal(signal.SIGTERM, lambda _, __: sys.exit(0)) try: state.main.start() except: print('shutting down...') - shutdown() + asyncio.run(shutdown())