# -*- coding: utf-8 -*-

from io import BytesIO
from queue import Queue
from threading import Thread
import gzip
import json
import os
import socket
import socks
import time
import urllib.error
import urllib.parse
import urllib.request

import ox
from tornado.ioloop import PeriodicCallback

import settings
import user.models

from websocket import trigger_event
from localnodes import LocalNodes
from tor_request import get_opener
from utils import user_sort_key, get_peer
import state
import db
import library

import logging
logger = logging.getLogger(__name__)

# Set to True to enable extra debug logging about node connectivity,
# queue processing, downloads and peering (all guarded by `if DEBUG_NODES:`).
DEBUG_NODES = False

class Node(Thread):
    """Worker thread handling all communication with one remote peer.

    Work is queued on ``self._q`` and processed sequentially in ``run``:
    ``'ping'`` re-checks connectivity, ``'send_response'`` delivers a queued
    peering action, and ``[method_name, args]`` calls ``method_name(*args)``
    on this node.  Non-ping actions are re-queued while the peer is offline.
    """
    host = None
    # Local address reported by LocalNodes, or None to use the .onion address.
    local = None
    # Port for local connections; resolve() overwrites it.  The class default
    # guarantees ``url`` never reads an unset attribute.
    port = 9851
    _online = None
    # Base timeout in seconds; individual requests use multiples of it.
    TIMEOUT = 5

    # Headers sent with plain GET requests (log, file and preview downloads).
    headers = {
        'X-Node-Protocol': settings.NODE_PROTOCOL,
        'User-Agent': settings.USER_AGENT,
        'Accept-Encoding': 'gzip',
    }

    def __init__(self, nodes, user_id):
        """Create and start the worker thread for peer ``user_id``.

        ``nodes`` is the owning ``Nodes`` instance; it provides local address
        lookup and the registry this node removes itself from on a 403.
        """
        self._nodes = nodes
        self.user_id = user_id
        self._opener = get_opener(self.user_id)
        self._q = Queue()
        # Periodically re-check connectivity from the tornado main loop.
        self._pingcb = PeriodicCallback(self.ping, 10 * settings.server['pull_interval'])
        state.main.add_callback(self._pingcb.start)
        Thread.__init__(self)
        self.daemon = True
        self.start()

    def run(self):
        """Process queued actions until ``state.shutdown`` is set."""
        self.ping()
        while not state.shutdown:
            action = self._q.get()
            if state.shutdown:
                break
            if not action:
                # Empty wake-up sentinel from join(); the loop re-checks shutdown.
                continue
            try:
                self._process(action)
            except Exception:
                # Keep the worker alive: an uncaught error would kill this
                # thread and strand every action queued for the peer.
                logger.error('node action %s failed (%s)', action, self.user_id, exc_info=True)

    def _process(self, action):
        """Handle one queued action, re-queueing it if the peer is offline."""
        if action == 'ping':
            self.online = self.can_connect()
        elif action == 'send_response':
            if self.online:
                self._send_response()
            else:
                self._wait_if_idle()
                self.send_response()
        elif isinstance(action, list) and len(action) == 2:
            if self.online:
                self._call(action[0], *action[1])
            else:
                self._wait_if_idle()
                self.queue(action[0], *action[1])
        else:
            logger.debug('unknown action %s', action)

    def _wait_if_idle(self):
        """Back off 5s before a retry unless other work is already waiting."""
        if not self._q.qsize():
            time.sleep(5)

    def join(self):
        """Wake the worker so it notices ``state.shutdown`` and exits.

        Unlike ``Thread.join`` this does not wait for the thread to finish.
        """
        self._q.put('')

    def ping(self):
        """Queue a connectivity check if the peer could be reachable at all."""
        if state.online or self.get_local():
            self._q.put('ping')

    def queue(self, action, *args):
        """Queue ``action(*args)`` to run on this node's worker thread."""
        logger.debug('queue node action %s->%s%s', self.user_id, action, args)
        self._q.put([action, args])

    def _call(self, action, *args):
        """Call method ``action`` on this node synchronously and log the result."""
        r = getattr(self, action)(*args)
        logger.debug('call node api %s->%s%s = %s', self.user_id, action, args, r)

    @property
    def url(self):
        """Base URL of the peer: the local address if known, else its .onion."""
        if self.local:
            if ':' in self.local:
                # IPv6 literals must be bracketed in URLs.
                url = 'https://[%s]:%s' % (self.local, self.port)
            else:
                url = 'https://%s:%s' % (self.local, self.port)
        else:
            url = 'https://%s.onion:9851' % self.user_id
        return url

    @property
    def online(self):
        """Last known reachability (None until first determined)."""
        return self._online

    @online.setter
    def online(self, online):
        changed = self._online != online
        self._online = online
        if changed:
            self.trigger_status()

    def resolve(self):
        """Refresh ``local``/``port`` from the local-discovery data, if any."""
        r = self.get_local()
        if r:
            self.local = r['host']
            if 'port' in r:
                self.port = r['port']
        else:
            self.local = None
            self.port = 9851

    def is_local(self):
        """Return True if local discovery currently knows this peer."""
        local = self._nodes.local if self._nodes else None
        return bool(local and self.user_id in local)

    def get_local(self):
        """Return the local-discovery record for this peer, or None."""
        if self._nodes and self._nodes.local:
            return self._nodes.local.get_data(self.user_id)
        return None

    def _forbidden(self):
        """Handle a 403 from the peer: drop peering, unregister, go offline."""
        if state.tasks:
            state.tasks.queue('peering', (self.user_id, False))
        # pop() rather than del: the node may already have been removed.
        self._nodes._nodes.pop(self.user_id, None)
        self.online = False

    def request(self, action, *args):
        """POST ``[action, args]`` as JSON to the peer's API.

        Returns the decoded JSON response, or None on a transport error, an
        HTTP error or a protocol-version mismatch (the node is then marked
        offline).  A 403 additionally drops the peering via ``_forbidden``.
        """
        self.resolve()
        url = self.url
        if self.local:
            logger.debug('request:%s(%s:%s): %s%s', self.user_id, self.local, self.port, action, list(args))
        else:
            logger.debug('request:%s: %s%s', self.user_id, action, list(args))
        content = json.dumps([action, args]).encode()
        headers = {
            'User-Agent': settings.USER_AGENT,
            'X-Node-Protocol': settings.NODE_PROTOCOL,
            'Accept': 'text/plain',
            'Accept-Encoding': 'gzip',
            'Content-Type': 'application/json',
        }
        self._opener.addheaders = list(headers.items())
        try:
            r = self._opener.open(url, data=content, timeout=self.TIMEOUT*12)
        except urllib.error.HTTPError as e:
            if e.code == 403:
                logger.debug('403: %s (%s)', url, self.user_id)
                self._forbidden()
                return None
            logger.debug('urllib2.HTTPError %s %s', e, e.code)
            self.online = False
            return None
        except urllib.error.URLError as e:
            logger.debug('urllib2.URLError %s', e)
            self.online = False
            return None
        except socket.timeout:
            logger.debug('timeout %s', url)
            self.online = False
            return None
        except Exception:
            logger.debug('unknown url error', exc_info=True)
            self.online = False
            return None
        data = r.read()
        if r.headers.get('content-encoding', None) == 'gzip':
            data = gzip.GzipFile(fileobj=BytesIO(data)).read()

        version = r.headers.get('X-Node-Protocol', None)
        if version != settings.NODE_PROTOCOL:
            logger.debug('version does not match local: %s remote %s (%s)', settings.NODE_PROTOCOL, version, self.user_id)
            self.online = False
            # A missing header (None) must not be ordered against a string.
            if version and version > settings.NODE_PROTOCOL:
                state.update_required = True
            return None

        return json.loads(data.decode('utf-8'))

    def can_connect(self):
        """Return True if the peer answers with a matching protocol version."""
        self.resolve()
        url = self.url
        if not state.online and not self.local:
            return False
        try:
            if url:
                headers = {
                    'User-Agent': settings.USER_AGENT,
                    'X-Node-Protocol': settings.NODE_PROTOCOL,
                    'Accept-Encoding': 'gzip',
                }
                self._opener.addheaders = list(headers.items())
                # NOTE(review): if get_opener() returns a urllib OpenerDirector
                # this attribute is not read by open() (timeout is a parameter
                # there) -- confirm before relying on a 2s timeout.
                self._opener.timeout = 2
                r = self._opener.open(url)
                version = r.headers.get('X-Node-Protocol', None)
                if version != settings.NODE_PROTOCOL:
                    logger.debug('version does not match local: %s remote %s', settings.NODE_PROTOCOL, version)
                    return False
                r.read()
                if DEBUG_NODES:
                    logger.debug('can connect to: %s', url)
                return True
        except Exception:
            # Any failure simply means "not reachable right now".
            if DEBUG_NODES:
                logger.debug('can not connect to: %s', url)
        return False

    def is_online(self):
        """Return True if the peer was last reachable or is known locally."""
        return self.online or self.is_local()

    def send_response(self):
        """Queue delivery of this user's pending peering action."""
        self._q.put('send_response')

    def _send_response(self):
        """Send the peering action queued for this user, if any.

        'sent' pending state -> requestPeering; peered with nothing pending
        -> acceptPeering; anything else -> removePeering.
        """
        user_pending = user_peered = None
        user_queued = False
        with db.session():
            u = user.models.User.get(self.user_id)
            if u:
                user_pending = u.pending
                user_peered = u.peered
                user_queued = u.queued
                # Only log when the user exists; u may be None.
                if DEBUG_NODES:
                    logger.debug('go online peered=%s queued=%s %s (%s)', u.peered, u.queued, u.id, u.nickname)

        if not user_queued:
            return
        if DEBUG_NODES:
            logger.debug('connected to %s', self.url)
            logger.debug('queued peering event pending=%s peered=%s', user_pending, user_peered)
        if user_pending == 'sent':
            self.peering('requestPeering')
        elif user_pending == '' and user_peered:
            self.peering('acceptPeering')
        else:
            # FIXME: what about cancel/reject peering here?
            self.peering('removePeering')

    def trigger_status(self):
        """Broadcast this node's online state, once it is known."""
        if self.online is not None:
            trigger_event('status', {
                'id': self.user_id,
                'online': self.online
            })

    def pullChanges(self):
        """Fetch new entries of the peer's change log and apply them.

        Resumes from the size of the local log copy using a Range request;
        complete lines are appended as they arrive and ``peer.apply_log()``
        runs if anything was written.  Returns False on failure or shutdown.
        """
        if state.shutdown:
            return
        self.online = self.can_connect()
        if not self.online or state.shutdown:
            return
        self.resolve()
        peer = get_peer(self.user_id)
        path = peer._logpath
        if os.path.exists(path):
            size = os.path.getsize(path)
        else:
            size = 0
        url = '%s/log' % self.url
        if DEBUG_NODES:
            logger.debug('pullChanges: %s [%s]', self.user_id, url)
        headers = self.headers.copy()
        if size:
            headers['Range'] = '%s-' % size
        self._opener.addheaders = list(headers.items())
        try:
            r = self._opener.open(url, timeout=self.TIMEOUT*60)
        except urllib.error.HTTPError as e:
            if e.code == 403:
                logger.debug('pullChanges 403: %s (%s)', url, self.user_id)
                self._forbidden()
            else:
                logger.debug('unknown http error %s %s (%s)', e.code, url, self.user_id)
            return False
        except socket.timeout:
            logger.debug('timeout %s', url)
            return False
        except socks.GeneralProxyError:
            logger.debug('openurl failed %s', url)
            return False
        except urllib.error.URLError as e:
            logger.debug('openurl failed urllib2.URLError %s', e.reason)
            return False
        except Exception:
            logger.debug('openurl failed %s', url, exc_info=True)
            return False
        if r.getcode() in (200, 206):
            changed = False
            chunk_size = 16 * 1024
            # 206 = partial content continuing our copy; 200 = full log.
            mode = 'ab' if r.getcode() == 206 else 'wb'
            content = b''

            try:
                if r.headers.get('content-encoding', None) == 'gzip':
                    fileobj = gzip.GzipFile(fileobj=r)
                else:
                    fileobj = r
                for chunk in iter(lambda: fileobj.read(chunk_size), b''):
                    content += chunk
                    # Write only complete lines; keep the partial tail buffered.
                    eol = content.rfind(b'\n') + 1
                    if eol > 0:
                        with open(path, mode) as fd:
                            fd.write(content[:eol])
                        content = content[eol:]
                        mode = 'ab'
                        changed = True
                    if state.shutdown:
                        return False
                    if state.bandwidth:
                        while not state.bandwidth.download(chunk_size) and not state.shutdown:
                            time.sleep(0.1)
                if content:
                    with open(path, mode) as fd:
                        fd.write(content)
                    changed = True
                if changed:
                    peer.apply_log()
            except Exception:
                logger.debug('download failed %s', url, exc_info=True)
                return False
        else:
            logger.debug('FAILED %s', url)
            return False

    def peering(self, action):
        """Send a peering ``action`` (e.g. 'requestPeering') to the peer.

        On success the user's queued flag and stored message are cleared; an
        accepted peering also triggers an immediate ``pullChanges``.  Always
        returns True.
        """
        pull_changes = False
        with db.session():
            u = user.models.User.get_or_create(self.user_id)
            user_info = u.info
        if action in ('requestPeering', 'acceptPeering'):
            r = self.request(action, settings.preferences['username'], user_info.get('message'))
        else:
            r = self.request(action, user_info.get('message'))
        if r is not None:
            with db.session():
                u = user.models.User.get(self.user_id)
                u.queued = False
                if 'message' in u.info:
                    del u.info['message']
                u.save()
            if action == 'acceptPeering':
                pull_changes = True
        else:
            logger.debug('peering failed? %s %s', action, r)
        if action in ('cancelPeering', 'rejectPeering', 'removePeering'):
            self.online = False
        with db.session():
            u = user.models.User.get(self.user_id)
            trigger_event('peering.%s' % action.replace('Peering', ''), u.json())
        if pull_changes:
            self.pullChanges()
        return True

    def download(self, item):
        """Download ``item``'s file from the peer and store it via ``item.save_file``.

        Reports progress through ``state.downloads.transfers`` about once per
        second and aborts if the transfer was cancelled.  Returns False on
        failure, otherwise the result of ``item.save_file``.
        """
        self.resolve()
        url = '%s/get/%s' % (self.url, item.id)
        if DEBUG_NODES:
            logger.debug('download %s', url)
        self._opener.addheaders = list(self.headers.items())
        try:
            r = self._opener.open(url, timeout=self.TIMEOUT*5)
        except socket.timeout:
            logger.debug('timeout %s', url)
            return False
        except socks.GeneralProxyError:
            logger.debug('openurl failed %s', url)
            return False
        except urllib.error.URLError as e:
            logger.debug('openurl failed urllib2.URLError %s', e.reason)
            return False
        except Exception:
            logger.debug('openurl failed %s', url, exc_info=True)
            return False
        if r.getcode() == 200:
            try:
                if r.headers.get('content-encoding', None) == 'gzip':
                    fileobj = gzip.GzipFile(fileobj=r)
                else:
                    fileobj = r
                content = []
                ct = time.time()
                size = item.info['size']
                received = 0
                chunk_size = 16*1024
                for chunk in iter(lambda: fileobj.read(chunk_size), b''):
                    content.append(chunk)
                    received += len(chunk)
                    if time.time() - ct > 1:
                        ct = time.time()
                        if state.shutdown:
                            return False
                        t = state.downloads.transfers.get(item.id)
                        if not t: # transfer was canceled
                            trigger_event('transfer', {
                                'id': item.id, 'progress': -1
                            })
                            return False
                        # Guard against a zero/unknown size.
                        t['progress'] = received / size if size else 0
                        trigger_event('transfer', {
                            'id': item.id, 'progress': t['progress']
                        })
                        state.downloads.transfers[item.id] = t
                    if state.bandwidth:
                        while not state.bandwidth.download(chunk_size) and not state.shutdown:
                            time.sleep(0.1)
                return item.save_file(b''.join(content))
            except Exception:
                logger.debug('download failed %s', url, exc_info=True)
                return False
        else:
            logger.debug('FAILED %s', url)
            return False

    def download_preview(self, item_id):
        """Fetch the preview image for ``item_id`` into the icon cache.

        Returns True if a preview was stored, False otherwise (a 404 is a
        silent False).
        """
        from item.icons import icons
        self.resolve()
        if DEBUG_NODES:
            logger.debug('download preview for %s from %s', item_id, self.url)
        url = '%s/preview/%s' % (self.url, item_id)
        self._opener.addheaders = list(self.headers.items())
        try:
            r = self._opener.open(url, timeout=self.TIMEOUT*2)
        except socket.timeout:
            logger.debug('timeout %s', url)
            return False
        except socks.GeneralProxyError:
            logger.debug('download failed %s', url)
            return False
        except Exception:
            logger.debug('download failed %s', url, exc_info=True)
            self.online = False
            return False
        code = r.getcode()
        if code == 200:
            try:
                if r.headers.get('content-encoding', None) == 'gzip':
                    fileobj = gzip.GzipFile(fileobj=r)
                else:
                    fileobj = r
                content = fileobj.read()
                key = 'preview:' + item_id
                icons[key] = content
                # Drop derived (resized) variants of the old preview.
                icons.clear(key+':')
                return True
            except Exception:
                logger.debug('preview download failed %s', url, exc_info=True)
        elif code == 404:
            pass
        else:
            logger.debug('FAILED %s', url)
        return False

    def download_upgrade(self, release):
        """Download missing modules of ``release`` into ``settings.update_path``.

        Each downloaded module is checked against its sha1 and deleted on a
        mismatch.  Returns False on an HTTP or checksum failure, True once
        every module is present.
        """
        self.resolve()
        for module in release['modules'].values():
            name = module['name']
            path = os.path.join(settings.update_path, name)
            if os.path.exists(path):
                continue
            url = '%s/oml/%s' % (self.url, name)
            headers = {
                'User-Agent': settings.USER_AGENT,
            }
            self._opener.addheaders = list(headers.items())
            r = self._opener.open(url)
            if r.getcode() != 200:
                return False
            # The payload is bytes, so the file must be opened in binary mode.
            with open(path, 'wb') as fd:
                fd.write(r.read())
            # Hash only after the file is closed so all data is flushed.
            if ox.sha1sum(path) != module['sha1']:
                logger.error('invalid update!')
                os.unlink(path)
                return False
        return True

    def upload(self, items):
        """Ask the peer to add ``items`` to its inbox; True on a truthy reply."""
        logger.debug('add items to %s\'s inbox: %s', self.user_id, items)
        r = self.request('upload', items)
        return bool(r)

class Nodes(Thread):
    """Registry of ``Node`` workers plus a thread serialising node-level work.

    Queue entries are lists: ``['add', user_id[, send_response]]``,
    ``['pull']``, or ``[target, action, *args]`` which is dispatched via
    ``_call``.
    """
    # Class-level default only; __init__ gives each instance its own dict.
    _nodes = {}
    local = None
    _pulling = False

    def __init__(self):
        """Queue workers for all peered/queued users and start the thread."""
        # Per-instance registry: a mutable class attribute would be shared
        # between every Nodes instance.
        self._nodes = {}
        self._q = Queue()
        with db.session():
            for u in user.models.User.query.filter_by(peered=True):
                self.queue('add', u.id)
                get_peer(u.id)
            for u in user.models.User.query.filter_by(queued=True):
                logger.debug('adding queued node... %s', u.id)
                self.queue('add', u.id, True)
        self.local = LocalNodes()
        self._pullcb = PeriodicCallback(self.pull, settings.server['pull_interval'])
        state.main.add_callback(self._pullcb.start)
        Thread.__init__(self)
        self.daemon = True
        self.start()

    def run(self):
        """Sync the library, then process queued tasks until shutdown."""
        library.sync_db()
        self.queue('pull')
        while not state.shutdown:
            args = self._q.get()
            if not args:
                # None is the wake-up sentinel put by join().
                continue
            if DEBUG_NODES:
                logger.debug('processing nodes queue: next: "%s", %s entries in queue', args[0], self._q.qsize())
            try:
                if args[0] == 'add':
                    self._add(*args[1:])
                elif args[0] == 'pull':
                    self._pull()
                else:
                    self._call(*args)
            except Exception:
                # An uncaught error would stop this thread for good.
                logger.error('nodes task %s failed', args, exc_info=True)

    def queue(self, *args):
        """Queue a task (see class docstring) for the nodes thread."""
        if args and DEBUG_NODES:
            logger.debug('queue "%s", %s entries in queue', args, self._q.qsize())
        self._q.put(list(args))

    def peer_queue(self, peer, action, *args):
        """Queue ``action(*args)`` on ``peer``'s node, creating/pinging it first."""
        if peer not in self._nodes:
            self._add(peer)
        elif not self._nodes[peer].is_online():
            self._nodes[peer].ping()
        self._nodes[peer].queue(action, *args)

    def is_online(self, id):
        """Return True if a node for ``id`` exists and is online."""
        return id in self._nodes and self._nodes[id].is_online()

    def download(self, id, item):
        """Download ``item`` from peer ``id``; False if the peer is unknown."""
        return id in self._nodes and self._nodes[id].download(item)

    def download_preview(self, id, item):
        """Fetch a preview from peer ``id`` if that peer is online."""
        return id in self._nodes and \
            self._nodes[id].is_online() and \
            self._nodes[id].download_preview(item)

    def _call(self, target, action, *args):
        """Run ``action(*args)`` synchronously on the nodes selected by ``target``.

        ``target`` is 'all', 'peered', 'online' or a single user id (a node is
        created for an unknown id).
        """
        if target == 'all':
            nodes = list(self._nodes.values())
        elif target == 'peered':
            with db.session():
                from user.models import User
                ids = {
                    u.id for u in
                    User.query.filter(User.id != settings.USER_ID).filter_by(peered=True).all()
                }
            nodes = [n for n in list(self._nodes.values()) if n.user_id in ids]
        elif target == 'online':
            nodes = [n for n in list(self._nodes.values()) if n.online]
        else:
            if target not in self._nodes:
                self._add(target)
            nodes = [self._nodes[target]]
        for node in nodes:
            node._call(action, *args)

    def _add(self, user_id, send_response=False):
        """Create (or ping) the node for ``user_id``; optionally queue a response."""
        if user_id not in self._nodes:
            from user.models import User
            with db.session():
                User.get_or_create(user_id)
            self._nodes[user_id] = Node(self, user_id)
        else:
            self._nodes[user_id].ping()
        if send_response:
            self._nodes[user_id].send_response()

    def pull(self):
        """Periodic callback: queue a pull unless one is already running."""
        if not self._pulling:
            self.queue('pull')

    def _pull(self):
        """Pull change logs from every online peered user, in sort order."""
        if not state.sync_enabled or settings.preferences.get('downloadRate') == 0:
            return
        if state.activity and state.activity.get('activity') == 'import':
            return
        if state.shutdown:
            return
        self._pulling = True
        try:
            users = []
            with db.session():
                from user.models import User
                for u in User.query.filter(User.id != settings.USER_ID).filter_by(peered=True).all():
                    users.append(u.json(['id', 'index', 'name']))
            users.sort(key=user_sort_key)
            for u in users:
                if state.shutdown:
                    break
                node = self._nodes.get(u['id'])
                if node and node.is_online():
                    node.pullChanges()
        finally:
            # Always reset, otherwise a single failure would block all future pulls.
            self._pulling = False

    def join(self):
        """Wake all worker threads, close local discovery and join briefly."""
        self._q.put(None)
        for node in list(self._nodes.values()):
            node.join()
        if self.local:
            self.local.close()
        return super().join(1)

def publish_node():
    """Check our online status now, then re-check it every 60 seconds."""
    update_online()
    checker = PeriodicCallback(update_online, 60000)
    state._online = checker
    checker.start()

def update_online():
    """Sync ``state.online`` with Tor's status and broadcast any change.

    When we come online, every known node re-broadcasts its own status too.
    """
    # Coerce to bool so a missing Tor controller (None) never leaks into
    # state.online or the broadcast 'online' value.
    online = bool(state.tor and state.tor.is_online())
    if online == state.online:
        return
    state.online = online
    trigger_event('status', {
        'id': settings.USER_ID,
        'online': state.online
    })
    if state.online and state.nodes:
        for node in list(state.nodes._nodes.values()):
            node.trigger_status()