fix local peer discovery

This commit is contained in:
j 2024-06-08 13:31:46 +01:00
parent d8cd9ecd4f
commit 2b58800caa
3 changed files with 56 additions and 39 deletions

View file

@ -1,11 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
import socket import socket
import netifaces import netifaces
from zeroconf import ( from zeroconf import (
ServiceBrowser, ServiceInfo, ServiceStateChange, Zeroconf ServiceBrowser, ServiceInfo, ServiceStateChange
) )
from zeroconf.asyncio import AsyncZeroconf
from tornado.ioloop import PeriodicCallback from tornado.ioloop import PeriodicCallback
import settings import settings
@ -64,87 +66,101 @@ class LocalNodes(dict):
def setup(self): def setup(self):
self.local_ips = get_broadcast_interfaces() self.local_ips = get_broadcast_interfaces()
self.zeroconf = {ip: Zeroconf(interfaces=[ip]) for ip in self.local_ips} self.zeroconf = {ip: AsyncZeroconf(interfaces=[ip]) for ip in self.local_ips}
self.register_service() asyncio.create_task(self.register_service())
self.browse() self.browse()
def _update_if_ip_changed(self): def _update_if_ip_changed(self):
local_ips = get_broadcast_interfaces() local_ips = get_broadcast_interfaces()
username = settings.preferences.get('username', 'anonymous') username = settings.preferences.get('username', 'anonymous')
if local_ips != self.local_ips or self.username != username: if local_ips != self.local_ips or self.username != username:
self.close() asyncio.run(self.close())
self.setup() self.setup()
def browse(self): def browse(self):
self.browser = { self.browser = {
ip: ServiceBrowser(self.zeroconf[ip], self.service_type, handlers=[self.on_service_state_change]) ip: ServiceBrowser(self.zeroconf[ip].zeroconf, self.service_type, handlers=[self.on_service_state_change])
for ip in self.zeroconf for ip in self.zeroconf
} }
def register_service(self): async def register_service(self):
if self.local_info: if self.local_info:
for local_ip, local_info in self.local_info: for local_ip, local_info in self.local_info:
self.zeroconf[local_ip].unregister_service(local_info) self.zeroconf[local_ip].async_unregister_service(local_info)
self.local_info = None self.local_info = None
local_name = socket.gethostname().partition('.')[0] + '.local.' local_name = socket.gethostname().partition('.')[0] + '.local.'
port = settings.server['node_port'] port = settings.server['node_port']
self.local_info = []
self.username = settings.preferences.get('username', 'anonymous') self.username = settings.preferences.get('username', 'anonymous')
desc = { desc = {
'username': self.username 'username': self.username,
'id': settings.USER_ID,
} }
self.local_info = [] tasks = []
for i, local_ip in enumerate(get_broadcast_interfaces()): for i, local_ip in enumerate(get_broadcast_interfaces()):
if i: if i:
name = '%s-%s [%s].%s' % (desc['username'], i+1, settings.USER_ID, self.service_type) name = '%s [%s].%s' % (desc['username'], i, self.service_type)
else: else:
name = '%s [%s].%s' % (desc['username'], settings.USER_ID, self.service_type) name = '%s.%s' % (desc['username'], self.service_type)
local_info = ServiceInfo(self.service_type, name,
socket.inet_aton(local_ip), port, 0, 0, desc, local_name) addresses = [socket.inet_aton(local_ip)]
self.zeroconf[local_ip].register_service(local_info) local_info = ServiceInfo(self.service_type, name, port, 0, 0, desc, local_name, addresses=addresses)
task = self.zeroconf[local_ip].async_register_service(local_info)
tasks.append(task)
self.local_info.append((local_ip, local_info)) self.local_info.append((local_ip, local_info))
await asyncio.gather(*tasks)
def __del__(self): def __del__(self):
self.close() self.close()
def close(self): async def close(self):
if self.local_info: if self.local_info:
tasks = []
for local_ip, local_info in self.local_info: for local_ip, local_info in self.local_info:
try: try:
self.zeroconf[local_ip].unregister_service(local_info) task = self.zeroconf[local_ip].async_unregister_service(local_info)
tasks.append(task)
except: except:
logger.debug('exception closing zeroconf', exc_info=True) logger.debug('exception closing zeroconf', exc_info=True)
self.local_info = None self.local_info = None
if self.zeroconf: if self.zeroconf:
for local_ip in self.zeroconf: for local_ip in self.zeroconf:
try: try:
self.zeroconf[local_ip].close() task = self.zeroconf[local_ip].async_close()
tasks.append(task)
except: except:
logger.debug('exception closing zeroconf', exc_info=True) logger.debug('exception closing zeroconf', exc_info=True)
self.zeroconf = None self.zeroconf = None
for id in list(self): for id in list(self):
self.pop(id, None) self.pop(id, None)
await asyncio.gather(*tasks)
def on_service_state_change(self, zeroconf, service_type, name, state_change): def on_service_state_change(self, zeroconf, service_type, name, state_change):
if '[' not in name: info = zeroconf.get_service_info(service_type, name)
id = name.split('.')[0] if info and b'id' in info.properties:
else: id = info.properties[b'id'].decode()
id = name.split('[')[1].split(']')[0]
if id == settings.USER_ID: if id == settings.USER_ID:
return return
if state_change is ServiceStateChange.Added: if state_change is ServiceStateChange.Added:
info = zeroconf.get_service_info(service_type, name) new = id not in self
if info:
self[id] = { self[id] = {
'id': id, 'id': id,
'host': socket.inet_ntoa(info.address), 'host': socket.inet_ntoa(info.addresses[0]),
'port': info.port 'port': info.port
} }
if info.properties: if info.properties:
for key, value in info.properties.items(): for key, value in info.properties.items():
key = key.decode() key = key.decode()
self[id][key] = value.decode() self[id][key] = value.decode()
logger.debug('add: %s [%s] (%s:%s)', self[id].get('username', 'anon'), id, self[id]['host'], self[id]['port']) logger.debug(
'%s: %s [%s] (%s:%s)',
'add' if new else 'update',
self[id].get('username', 'anon'),
id,
self[id]['host'],
self[id]['port']
)
if state.tasks and id in self: if state.tasks and id in self:
state.tasks.queue('addlocalinfo', self[id]) state.tasks.queue('addlocalinfo', self[id])
elif state_change is ServiceStateChange.Removed: elif state_change is ServiceStateChange.Removed:

View file

@ -622,12 +622,12 @@ class Nodes(Thread):
node.pullChanges() node.pullChanges()
self._pulling = False self._pulling = False
def join(self): async def join(self):
self._q.put(None) self._q.put(None)
for node in list(self._nodes.values()): for node in list(self._nodes.values()):
node.join() node.join()
if self.local: if self.local:
self.local.close() await self.local.close()
return super().join(1) return super().join(1)
def publish_node(): def publish_node():

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
import os import os
import sys import sys
import signal import signal
@ -59,13 +60,13 @@ def log_request(handler):
log_method("%d %s %.2fms", handler.get_status(), log_method("%d %s %.2fms", handler.get_status(),
handler._request_summary(), request_time) handler._request_summary(), request_time)
def shutdown(): async def shutdown():
state.shutdown = True state.shutdown = True
if state.tor: if state.tor:
state.tor._shutdown = True state.tor._shutdown = True
if state.nodes: if state.nodes:
logger.debug('shutdown nodes') logger.debug('shutdown nodes')
state.nodes.join() await state.nodes.join()
if state.downloads: if state.downloads:
logger.debug('shutdown downloads') logger.debug('shutdown downloads')
state.downloads.join() state.downloads.join()
@ -197,10 +198,10 @@ def run():
print('open browser at %s' % url) print('open browser at %s' % url)
logger.debug('Starting OML %s at %s', settings.VERSION, url) logger.debug('Starting OML %s at %s', settings.VERSION, url)
signal.signal(signal.SIGTERM, shutdown) signal.signal(signal.SIGTERM, lambda _, __: sys.exit(0))
try: try:
state.main.start() state.main.start()
except: except:
print('shutting down...') print('shutting down...')
shutdown() asyncio.run(shutdown())