# -*- coding: utf-8 -*-

import asyncio
import socket

import netifaces
from zeroconf import (
    ServiceBrowser, ServiceInfo, ServiceStateChange
)
from zeroconf.asyncio import AsyncZeroconf
from tornado.ioloop import PeriodicCallback

import settings
import state
from tor_request import get_opener
from utils import time_cache

import logging
logger = logging.getLogger(__name__)


@time_cache(3)
def can_connect(**data):
    """Return True if the local peer described by ``data`` is reachable.

    ``data`` is a peer entry as stored by LocalNodes and needs at least
    ``id``, ``host`` and ``port``. An HTTPS request is sent to the peer
    through ``get_opener(data['id'])``. The peer counts as reachable only
    if the request succeeds and its ``X-Node-Protocol`` response header
    equals ``settings.NODE_PROTOCOL``. Any error results in False.

    Results are memoised by ``time_cache(3)`` (presumably a 3 second
    lifetime -- confirm in utils).
    """
    try:
        opener = get_opener(data['id'])
        headers = {
            'User-Agent': settings.USER_AGENT,
            'X-Node-Protocol': settings.NODE_PROTOCOL,
            'Accept-Encoding': 'gzip',
        }
        opener.addheaders = list(headers.items())
        # NOTE(review): urllib's OpenerDirector.open() takes the timeout as
        # an argument and does not read an attribute; if get_opener returns
        # a stock OpenerDirector this assignment has no effect -- confirm.
        opener.timeout = 1
        # IPv6 literals must be wrapped in brackets inside a URL.
        if ':' in data['host']:
            url = 'https://[{host}]:{port}'.format(**data)
        else:
            url = 'https://{host}:{port}'.format(**data)
        response = opener.open(url)
        try:
            version = response.headers.get('X-Node-Protocol', None)
            if version != settings.NODE_PROTOCOL:
                logger.debug('version does not match local: %s remote %s (%s)', settings.NODE_PROTOCOL, version, data['id'])
                return False
            # Read and discard the response body before closing.
            response.read()
        finally:
            # Always release the connection, including on a version mismatch.
            response.close()
        return True
    except Exception:
        # Unreachable peers are routine during discovery; failures are
        # deliberately not logged to keep the debug log quiet.
        return False

def get_broadcast_interfaces():
    """Return the IPv4 addresses of all local interfaces that can broadcast.

    Addresses with a ``255.255.255.255`` netmask or without a broadcast
    address are skipped. Duplicates are removed. The result is sorted so
    that repeated calls return the same order for the same set of
    addresses (a plain ``list(set(...))`` has an arbitrary order).
    """
    addresses = {
        addr['addr']
        for iface in netifaces.interfaces()
        for addr in netifaces.ifaddresses(iface).get(socket.AF_INET, [])
        if addr.get('netmask') != '255.255.255.255' and addr.get('broadcast')
    }
    return sorted(addresses)

class LocalNodes(dict):
    """Peers discovered on the local network via zeroconf.

    The mapping is keyed by peer user id. Each value is a dict with ``id``,
    ``host`` and ``port`` plus the string-decoded TXT properties that the
    peer advertised.

    One AsyncZeroconf instance is created per broadcast-capable local IPv4
    address. This node is advertised on each of them, and one ServiceBrowser
    per instance feeds on_service_state_change.
    """
    service_type = '_oml._tcp.local.'
    # list of (local_ip, ServiceInfo) for the services we registered
    local_info = None
    # local IPv4 addresses the current zeroconf instances were created for
    local_ips = None
    # {local_ip: AsyncZeroconf}; None while discovery is not running
    zeroconf = None
    # {local_ip: ServiceBrowser}
    browser = None
    # username last advertised; compared periodically to detect changes
    username = None
    # task references kept so the event loop cannot drop them mid-run
    _register_task = None
    _restart_task = None

    def __init__(self):
        super().__init__()
        if not settings.server.get('localnode_discovery'):
            return
        self.setup()
        # Re-check interfaces and username every 60 seconds.
        self._ip_changed = PeriodicCallback(self._update_if_ip_changed, 60000)
        state.main.add_callback(self._ip_changed.start)

    def setup(self):
        """Create zeroconf instances for the current interfaces, advertise and browse.

        Must run while an asyncio event loop is running, because it schedules
        register_service() with asyncio.create_task().
        """
        self.local_ips = get_broadcast_interfaces()
        self.zeroconf = {ip: AsyncZeroconf(interfaces=[ip]) for ip in self.local_ips}
        # service name -> peer id, so removals can be resolved without a
        # (blocking, and by then usually failing) get_service_info() call
        self._service_ids = {}
        self._register_task = asyncio.create_task(self.register_service())
        self.browse()

    def _update_if_ip_changed(self):
        """Restart discovery if the local addresses or our username changed.

        This runs from a tornado PeriodicCallback inside the event loop.
        The restart is therefore scheduled as a task: asyncio.run() refuses
        to run inside a running loop.
        """
        if self._restart_task is not None and not self._restart_task.done():
            return  # a restart is already in progress
        local_ips = get_broadcast_interfaces()
        username = settings.preferences.get('username', 'anonymous')
        if set(local_ips) != set(self.local_ips or ()) or self.username != username:
            self._restart_task = asyncio.create_task(self._restart())

    async def _restart(self):
        """Close all zeroconf instances, then set discovery up again."""
        try:
            await self.close()
            self.setup()
        except Exception:
            logger.error('failed to restart local node discovery', exc_info=True)

    def browse(self):
        """Start one ServiceBrowser per zeroconf instance for our service type."""
        self.browser = {
            ip: ServiceBrowser(self.zeroconf[ip].zeroconf, self.service_type, handlers=[self.on_service_state_change])
            for ip in self.zeroconf
        }

    @staticmethod
    def _log_errors(results):
        """Log exceptions returned by ``asyncio.gather(..., return_exceptions=True)``."""
        for result in results:
            if isinstance(result, BaseException):
                logger.debug('exception closing zeroconf', exc_info=result)

    async def register_service(self):
        """Advertise this node on every interface that has a zeroconf instance.

        Services registered earlier (``self.local_info``) are unregistered
        first. One service is registered per interface. All but the first
        get an `` [i]`` suffix so their instance names differ. Does nothing
        while no zeroconf instances exist.
        """
        if not self.zeroconf:
            return
        if self.local_info:
            previous = [
                self.zeroconf[local_ip].async_unregister_service(local_info)
                for local_ip, local_info in self.local_info
                if local_ip in self.zeroconf
            ]
            self.local_info = None
            self._log_errors(await asyncio.gather(*previous, return_exceptions=True))

        local_name = socket.gethostname().partition('.')[0] + '.local.'
        port = settings.server['node_port']
        self.local_info = []
        self.username = settings.preferences.get('username', 'anonymous')
        desc = {
            'username': self.username,
            'id': settings.USER_ID,
        }
        registrations = []
        # Iterate the zeroconf instances themselves (not a fresh interface
        # scan) so every address used here has a matching instance.
        for i, local_ip in enumerate(self.zeroconf):
            if i:
                name = '%s [%s].%s' % (desc['username'], i, self.service_type)
            else:
                name = '%s.%s' % (desc['username'], self.service_type)
            addresses = [socket.inet_aton(local_ip)]
            local_info = ServiceInfo(self.service_type, name, port, 0, 0, desc, local_name, addresses=addresses)
            registrations.append(self.zeroconf[local_ip].async_register_service(local_info))
            self.local_info.append((local_ip, local_info))
        await asyncio.gather(*registrations)

    def __del__(self):
        # close() is a coroutine; calling it without awaiting does nothing.
        # Best effort: schedule it when an event loop is running, otherwise
        # skip cleanup.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        loop.create_task(self.close())

    async def close(self):
        """Unregister our services, close every zeroconf instance and forget all peers.

        Our services are unregistered before the instances are closed.
        Errors from individual zeroconf calls are logged at debug level and
        otherwise ignored, so one failing interface does not stop the
        others from closing.
        """
        # NOTE(review): the ServiceBrowser objects in self.browser are not
        # cancelled here; confirm that closing the Zeroconf is enough.
        if self.local_info and self.zeroconf:
            unregister = [
                self.zeroconf[local_ip].async_unregister_service(local_info)
                for local_ip, local_info in self.local_info
                if local_ip in self.zeroconf
            ]
            self._log_errors(await asyncio.gather(*unregister, return_exceptions=True))
        self.local_info = None
        if self.zeroconf:
            closing = [zc.async_close() for zc in self.zeroconf.values()]
            # Drop the instances before awaiting so a concurrent close()
            # does not close them twice.
            self.zeroconf = None
            self._log_errors(await asyncio.gather(*closing, return_exceptions=True))
        self.clear()

    def on_service_state_change(self, zeroconf, service_type, name, state_change):
        """ServiceBrowser handler: record peers that appear, forget those that leave.

        Only Added and Removed events are acted on.
        """
        if state_change is ServiceStateChange.Added:
            self._on_added(zeroconf, service_type, name)
        elif state_change is ServiceStateChange.Removed:
            self._on_removed(name)

    def _on_added(self, zeroconf, service_type, name):
        """Look up a newly seen service and store it as a peer entry."""
        info = zeroconf.get_service_info(service_type, name)
        if not info or not info.properties:
            return
        raw_id = info.properties.get(b'id')
        if not raw_id:
            return
        node_id = raw_id.decode()
        if node_id == settings.USER_ID:
            return  # our own advertisement
        if len(node_id) != settings.ID_LENGTH:
            return
        if not info.addresses:
            logger.debug('no address for %s (%s)', name, node_id)
            return
        entry = {
            'id': node_id,
            'host': socket.inet_ntoa(info.addresses[0]),
            'port': info.port,
        }
        for key, value in info.properties.items():
            # TXT keys without a value are reported as None; skip them.
            if value is not None:
                entry[key.decode()] = value.decode()
        new = node_id not in self
        # Store the complete entry at once so readers never see a partial one.
        self[node_id] = entry
        self._service_ids[name] = node_id
        logger.debug(
            '%s: %s [%s] (%s:%s)',
            'add' if new else 'update',
            entry.get('username', 'anon'),
            node_id,
            entry['host'],
            entry['port'],
        )
        if state.tasks:
            state.tasks.queue('addlocalinfo', entry)

    def _on_removed(self, name):
        """Forget the peer that was advertised under service ``name``."""
        node_id = self._service_ids.pop(name, None)
        if node_id is None:
            return
        logger.debug('remove: %s', node_id)
        self.pop(node_id, None)
        if state.tasks:
            state.tasks.queue('removelocalinfo', node_id)

    def get_data(self, user_id):
        """Return the stored entry for ``user_id`` if can_connect() succeeds, else None."""
        data = self.get(user_id)
        if data and can_connect(**data):
            return data
        return None