use mdns for local peer discovery

This commit is contained in:
j 2016-03-14 14:31:56 +01:00
parent bea7c57515
commit 417195cfd1
7 changed files with 132 additions and 259 deletions

View file

@ -1,26 +1,23 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vi:si:et:sw=4:sts=4:ts=4 # vi:si:et:sw=4:sts=4:ts=4
import json
import socket import socket
import struct
import _thread
from threading import Thread
import time
import select
from utils import get_public_ipv6, get_local_ipv4, get_interface from zeroconf import (
from settings import preferences, server, USER_ID get_all_addresses,
import state ServiceBrowser, ServiceInfo, ServiceStateChange, Zeroconf
import db )
import user.models from tornado.ioloop import PeriodicCallback
from tor_request import get_opener
import settings import settings
import state
from tor_request import get_opener
from utils import get_local_ipv4
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def can_connect(data): def can_connect(data):
try: try:
opener = get_opener(data['id']) opener = get_opener(data['id'])
@ -47,235 +44,92 @@ def can_connect(data):
#logger.debug('failed to connect to local node %s', data, exc_info=True) #logger.debug('failed to connect to local node %s', data, exc_info=True)
return False return False
class LocalNodes(dict):
class LocalNodesBase(Thread): service_type = '_oml._tcp.local.'
local_info = None
_PORT = 9851 local_ips = None
_TTL = 1
def __init__(self, nodes):
self._socket = None
self._nodes = nodes
Thread.__init__(self)
if not server['localnode_discovery']:
return
self.daemon = True
self.start()
def get_packet(self):
self.host = self.get_ip()
if self.host:
message = json.dumps({
'id': USER_ID,
'username': preferences.get('username', 'anonymous'),
'host': self.host,
'port': server['node_port']
})
packet = message.encode()
else:
packet = None
return packet
def get_socket(self):
pass
def send(self):
pass
def receive(self):
last = time.time()
s = None
while not s and not state.shutdown:
try:
s = self.get_socket()
s.bind(('', self._PORT))
except OSError: # no local interface exists
self.wait(60)
while not state.shutdown:
try:
r, _, _ = select.select([s], [], [], 3)
if r:
data, addr = s.recvfrom(1024)
if not state.shutdown:
while data[-1] == 0:
data = data[:-1] # Strip trailing \0's
data = self.verify(data)
if data:
self.update_node(data)
except OSError: # no local interface exists
self.wait(60)
except:
if not state.shutdown:
logger.debug('receive failed. restart later', exc_info=True)
self.wait(60)
finally:
if not state.shutdown:
now = time.time()
if now - last > 60:
last = now
_thread.start_new_thread(self.send, ())
def verify(self, data):
try:
message = json.loads(data.decode())
except:
return None
for key in ['id', 'username', 'host', 'port']:
if key not in message:
return None
return message
def update_node(self, data):
#fixme use local link address
#print addr
if data['id'] != USER_ID:
if data['id'] not in self._nodes:
_thread.start_new_thread(self.new_node, (data, ))
elif can_connect(data):
self._nodes[data['id']] = data
def get(self, user_id):
if user_id in self._nodes:
if can_connect(self._nodes[user_id]):
return self._nodes[user_id]
def new_node(self, data):
logger.debug('NEW NODE %s', data)
if can_connect(data):
self._nodes[data['id']] = data
with db.session():
u = user.models.User.get(data['id'])
if u:
if u.info['username'] != data['username']:
u.info['username'] = data['username']
u.update_name()
u.info['local'] = data
u.save()
state.nodes.queue('add', u.id)
self.send()
def get_ip(self):
pass
def run(self):
self.send()
self.receive()
def join(self):
if self._socket:
try:
self._socket.shutdown(socket.SHUT_RDWR)
except OSError:
pass
self._socket.close()
return Thread.join(self)
def wait(self, timeout):
step = min(timeout, 1)
while not state.shutdown and timeout > 0:
time.sleep(step)
timeout -= step
class LocalNodes4(LocalNodesBase):
_BROADCAST = "239.255.255.250"
_TTL = 1
def send(self):
packet = self.get_packet()
if packet:
#logger.debug('send4 %s', packet)
sockaddr = (self._BROADCAST, self._PORT)
s = socket.socket (socket.AF_INET, socket.SOCK_DGRAM)
s.setsockopt (socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, self._TTL)
try:
s.sendto(packet + b'\0', sockaddr)
except OSError:
pass
except:
logger.debug('LocalNodes4.send failed', exc_info=True)
s.close()
def get_socket(self):
s = socket.socket (socket.AF_INET, socket.SOCK_DGRAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
mreq = struct.pack("=4sl", socket.inet_aton(self._BROADCAST), socket.INADDR_ANY)
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
self._socket = s
return s
def get_ip(self):
return get_local_ipv4()
class LocalNodes6(LocalNodesBase):
_BROADCAST = "ff02::1"
def send(self):
packet = self.get_packet()
if packet:
#logger.debug('send6 %s', packet)
ttl = struct.pack('@i', self._TTL)
address = self._BROADCAST + get_interface()
addrs = socket.getaddrinfo(address, self._PORT, socket.AF_INET6, socket.SOCK_DGRAM)
addr = addrs[0]
(family, socktype, proto, canonname, sockaddr) = addr
s = socket.socket(family, socktype, proto)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, ttl)
try:
s.sendto(packet + b'\0', sockaddr)
except:
logger.debug('LocalNodes6.send failed', exc_info=True)
s.close()
def get_socket(self):
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
group_bin = socket.inet_pton(socket.AF_INET6, self._BROADCAST) + b'\0'*4
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, group_bin)
self._socket = s
return s
def get_ip(self):
return get_public_ipv6()
class LocalNodes(object):
_nodes4 = None
_nodes6 = None
def __init__(self): def __init__(self):
self._nodes = {} if not settings.server.get('localnode_discovery'):
if not server['localnode_discovery']:
return return
self._nodes4 = LocalNodes4(self._nodes) self.setup()
#self._nodes6 = LocalNodes6(self._nodes) self._ip_changed = PeriodicCallback(self._update_if_ip_changed, 60000)
def cleanup(self): def setup(self):
if not state.shutdown: self.local_ips = sorted(get_all_addresses(socket.AF_INET))
for id in list(self._nodes.keys()): self.zeroconf = Zeroconf()
if not can_connect(self._nodes[id]): self.register_service()
with db.session(): self.browse()
u = user.models.User.get(id)
if u and 'local' in u.info: def _update_if_ip_changed(self):
del u.info['local'] local_ips = sorted(get_all_addresses(socket.AF_INET))
u.save() if local_ips != self.local_ips:
del self._nodes[id] self.close()
if state.shutdown: self.setup()
break
def browse(self):
self.browser = ServiceBrowser(self.zeroconf, self.service_type, handlers=[self.on_service_state_change])
def register_service(self):
if self.local_info:
self.zeroconf.unregister_service(self.local_info)
self.local_info = None
local_ip = get_local_ipv4()
if local_ip:
local_name = socket.gethostname().partition('.')[0] + '.local.'
port = settings.server['node_port']
desc = {
'username': settings.preferences.get('username', 'anonymous'),
}
self.local_info = ServiceInfo(self.service_type,
'%s.%s' % (settings.USER_ID, self.service_type),
socket.inet_aton(local_ip), port, 0, 0,
desc, local_name)
self.zeroconf.register_service(self.local_info)
def __del__(self):
self.close()
def close(self):
if self.local_info:
self.zeroconf.unregister_service(self.local_info)
self.local_info = None
if self.zeroconf:
try:
self.zeroconf.close()
except:
logger.debug('exception closing zeroconf', exc_info=True)
self.zeroconf = None
for id in list(self):
self.pop(id, None)
def on_service_state_change(self, zeroconf, service_type, name, state_change):
id = name.split('.')[0]
if id == settings.USER_ID:
return
if state_change is ServiceStateChange.Added:
info = zeroconf.get_service_info(service_type, name)
if info:
self[id] = {
'id': id,
'host': socket.inet_ntoa(info.address),
'port': info.port
}
if info.properties:
for key, value in info.properties.items():
key = key.decode()
self[id][key] = value.decode()
logger.debug('add localnode: %s', self[id])
if state.tasks:
state.tasks.queue('addlocalinfo', self[id])
elif state_change is ServiceStateChange.Removed:
logger.debug('remove localnode: %s', id)
self.pop(id, None)
if state.tasks:
state.tasks.queue('removelocalinfo', id)
def get(self, user_id): def get(self, user_id):
if user_id in self._nodes: if user_id in self:
if can_connect(self._nodes[user_id]): if can_connect(self[user_id]):
return self._nodes[user_id] return self[user_id]
def join(self):
if self._nodes4:
self._nodes4.join()
if self._nodes6:
self._nodes6.join()

View file

@ -105,8 +105,8 @@ class Node(Thread):
self.port = 9851 self.port = 9851
def get_local(self): def get_local(self):
if self._nodes and self._nodes._local: if self._nodes and self._nodes.local:
return self._nodes._local.get(self.user_id) return self._nodes.local.get(self.user_id)
return None return None
def request(self, action, *args): def request(self, action, *args):
@ -405,7 +405,7 @@ class Node(Thread):
class Nodes(Thread): class Nodes(Thread):
_nodes = {} _nodes = {}
_local = None local = None
_pulling = False _pulling = False
def __init__(self): def __init__(self):
@ -420,9 +420,7 @@ class Nodes(Thread):
for u in user.models.User.query.filter_by(queued=True): for u in user.models.User.query.filter_by(queued=True):
logger.debug('adding queued node... %s', u.id) logger.debug('adding queued node... %s', u.id)
self.queue('add', u.id, True) self.queue('add', u.id, True)
self._local = LocalNodes() self.local = LocalNodes()
self._cleanup = PeriodicCallback(lambda: self.queue('cleanup'), 120000)
self._cleanup.start()
self._pullcb = PeriodicCallback(self.pull, settings.server['pull_interval']) self._pullcb = PeriodicCallback(self.pull, settings.server['pull_interval'])
self._pullcb.start() self._pullcb.start()
Thread.__init__(self) Thread.__init__(self)
@ -435,19 +433,13 @@ class Nodes(Thread):
while not state.shutdown: while not state.shutdown:
args = self._q.get() args = self._q.get()
if args: if args:
if args[0] == 'cleanup': if args[0] == 'add':
self.cleanup()
elif args[0] == 'add':
self._add(*args[1:]) self._add(*args[1:])
elif args[0] == 'pull': elif args[0] == 'pull':
self._pull() self._pull()
else: else:
self._call(*args) self._call(*args)
def cleanup(self):
if not state.shutdown and self._local:
self._local.cleanup()
def queue(self, *args): def queue(self, *args):
self._q.put(list(args)) self._q.put(list(args))
@ -515,8 +507,8 @@ class Nodes(Thread):
self._q.put(None) self._q.put(None)
for node in list(self._nodes.values()): for node in list(self._nodes.values()):
node.join() node.join()
if self._local: if self.local:
self._local.join() self.local.close()
return Thread.join(self) return Thread.join(self)
def publish_node(): def publish_node():

View file

@ -40,7 +40,10 @@ class Tasks(Thread):
self.queue('scan') self.queue('scan')
import item.scan import item.scan
from item.models import sync_metadata, get_preview, get_cover from item.models import sync_metadata, get_preview, get_cover
from user.models import export_list, update_user_peering from user.models import (
export_list, update_user_peering,
add_local_info, remove_local_info,
)
shutdown = False shutdown = False
while not shutdown: while not shutdown:
p, m = self.q.get() p, m = self.q.get()
@ -68,6 +71,10 @@ class Tasks(Thread):
update_user_peering(*data) update_user_peering(*data)
elif action == 'ping': elif action == 'ping':
trigger_event('pong', data) trigger_event('pong', data)
elif action == 'addlocalinfo':
add_local_info(data)
elif action == 'removelocalinfo':
remove_local_info(data)
elif action == 'scan': elif action == 'scan':
item.scan.run_scan() item.scan.run_scan()
elif action == 'scanimport': elif action == 'scanimport':

View file

@ -123,9 +123,9 @@ def getUsers(data):
users.append(u.json()) users.append(u.json())
ids.add(u.id) ids.add(u.id)
if state.nodes: if state.nodes:
for id in state.nodes._local._nodes: for id in state.nodes.local:
if id not in ids: if id not in ids:
n = state.nodes._local._nodes[id].copy() n = state.nodes.local[id].copy()
n['online'] = True n['online'] = True
n['name'] = n['username'] n['name'] = n['username']
users.append(n) users.append(n)

View file

@ -54,8 +54,8 @@ class User(db.Model):
if not user: if not user:
user = cls(id=id, peered=False, online=False) user = cls(id=id, peered=False, online=False)
user.info = {} user.info = {}
if state.nodes and state.nodes._local and id in state.nodes._local._nodes: if state.nodes and state.nodes.local and id in state.nodes.local:
user.info['local'] = state.nodes._local._nodes[id] user.info['local'] = state.nodes.local[id]
user.info['username'] = user.info['local']['username'] user.info['username'] = user.info['local']['username']
user.update_name() user.update_name()
user.save() user.save()
@ -598,3 +598,21 @@ def update_user_peering(user_id, peered, username=None):
if u: if u:
u.update_peering(peered, username) u.update_peering(peered, username)
def remove_local_info(id):
with db.session():
u = User.get(id)
if u and 'local' in u.info:
del u.info['local']
u.save()
u.trigger_status()
def add_local_info(data):
with db.session():
u = User.get(data['id'])
if u:
if u.info['username'] != data['username']:
u.info['username'] = data['username']
u.update_name()
u.info['local'] = data
u.save()
state.nodes.queue('add', u.id)

View file

@ -9,3 +9,4 @@ PyPDF2==1.25.1
pysocks pysocks
stem stem
sqlitedict==1.4.0 sqlitedict==1.4.0
zeroconf

View file

@ -5,3 +5,4 @@ SQLAlchemy==1.0.12
pyopenssl>=0.15 pyopenssl>=0.15
pyCrypto>=2.6.1 pyCrypto>=2.6.1
pillow pillow
netifaces