# -*- coding: utf-8 -*-
from socketserver import ThreadingMixIn
from threading import Thread
import base64
import gzip
import hashlib
import http.server
import io
import json
import os
import socket
import socketserver
import time

from OpenSSL.SSL import (
    Connection,
    Context,
    TLSv1_2_METHOD,
    VERIFY_CLIENT_ONCE,
    VERIFY_FAIL_IF_NO_PEER_CERT,
    VERIFY_PEER,
)

import db
import settings
import state
import user
import utils
from changelog import changelog_size, changelog_path
from websocket import trigger_event

from . import nodeapi
from .sslsocket import fileobject

import logging
logger = logging.getLogger(__name__)


def get_service_id(connection):
    """Return the service id of the peer on the TLS ``connection``.

    The id is computed by ``utils.get_onion`` from the raw public key of the
    first certificate in the peer's chain whose signature algorithm is
    ED25519.

    Raises:
        Exception: if the peer presented no certificate chain, or none of its
            certificates uses ED25519.
    """
    # pyOpenSSL returns None (not an empty list) when the peer sent no chain;
    # treat that like a chain without a usable certificate.
    for cert in connection.get_peer_cert_chain() or ():
        if cert.get_signature_algorithm().decode() == "ED25519":
            public_key = cert.get_pubkey().to_cryptography_key().public_bytes_raw()
            return utils.get_onion(public_key)
    raise Exception("connection with invalid certificate")

class TLSTCPServer(socketserver.TCPServer):
    """TCP server whose listening socket is a pyOpenSSL TLS connection.

    Clients must present a certificate (``VERIFY_FAIL_IF_NO_PEER_CERT``). The
    certificate is accepted unconditionally during the handshake; the peer's
    identity is validated per request instead.
    """

    def _accept(self, connection, x509, errnum, errdepth, ok):
        # client_id is validated in request
        return True

    def __init__(self, server_address, HandlerClass, bind_and_activate=True):
        # Let TCPServer initialise its state without binding: the plain socket
        # it creates is replaced by the TLS-wrapped one below. Binding it here
        # would briefly claim the port and leave it to the garbage collector to
        # release before the TLS socket binds.
        socketserver.TCPServer.__init__(self, server_address, HandlerClass,
                                        bind_and_activate=False)
        self.socket.close()
        ctx = Context(TLSv1_2_METHOD)
        ctx.use_privatekey_file(settings.ssl_key_path)
        ctx.use_certificate_chain_file(settings.ssl_cert_path)
        # only allow clients with cert:
        ctx.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE | VERIFY_FAIL_IF_NO_PEER_CERT, self._accept)
        self.socket = Connection(ctx, socket.socket(self.address_family, self.socket_type))
        if bind_and_activate:
            # Same cleanup contract as TCPServer.__init__: never leak the
            # socket when bind/listen fails.
            try:
                self.server_bind()
                self.server_activate()
            except BaseException:
                self.server_close()
                raise

    def shutdown_request(self, request):
        """Send the TLS close_notify (best effort), then close the socket."""
        try:
            request.shutdown()
        except Exception:
            # The peer may already be gone; closing below is what matters.
            pass
        self.close_request(request)

class NodeServer(ThreadingMixIn, TLSTCPServer):
    """TLS node server handling each connection in its own thread."""

    # Set SO_REUSEADDR so the port can be rebound right after a restart.
    allow_reuse_address = True
    # Cleared by Server.stop() so handlers stop waiting on the upload limiter.
    _running = True


def api_call(action, user_id, args):
    """Run node API ``action`` on behalf of peer ``user_id``.

    Peering-management actions are dispatched for any caller; every other
    action only for users that are already peered. Dispatch calls
    ``nodeapi.api_<action>(user_id, *args)`` inside a database session.

    Returns:
        The API function's result; ``{}`` when the request is ignored because
        the user's peering is pending; ``None`` otherwise (unknown user, or a
        user that is neither peered nor pending).
    """
    peering_actions = (
        'requestPeering', 'acceptPeering', 'rejectPeering',
        'removePeering', 'cancelPeering',
    )
    with db.session():
        peer = user.models.User.get(user_id)
        if action in peering_actions or (peer and peer.peered):
            api_function = getattr(nodeapi, 'api_' + action)
            return api_function(user_id, *args)
        if not (peer and peer.pending):
            return None
        logger.debug('ignore request from pending peer[%s] %s (%s) (pending state: %s)',
                     user_id, action, args, peer.pending)
        return {}

class Handler(http.server.SimpleHTTPRequestHandler):
    """Request handler for the node-to-node HTTP API served over TLS.

    The peer is identified by the service id derived from its TLS client
    certificate (``get_service_id``). Routes:

    * ``GET /get/<id>``     -- the file of item ``<id>``
    * ``GET /preview/<id>`` -- the preview image of item ``<id>``
    * ``GET /log``          -- the changelog (a single byte ``Range`` is supported)
    * ``GET`` anything else -- a plain-text banner
    * ``POST``              -- a JSON ``[action, args]`` API call
    """

    def setup(self):
        # Wrap the TLS connection with the project's fileobject instead of
        # StreamRequestHandler.setup(), which would call makefile() on it
        # (presumably unsupported by pyOpenSSL connections).
        self.connection = self.request
        self.rfile = fileobject(self.connection, 'rb', self.rbufsize)
        self.wfile = fileobject(self.connection, 'wb', self.wbufsize)

    def version_string(self):
        # Value of the Server response header.
        return settings.USER_AGENT

    def log_message(self, format, *args):
        # Route http.server's access log to the module logger, and only when
        # settings.DEBUG_HTTP is enabled.
        if settings.DEBUG_HTTP:
            logger.debug("%s - - [%s] %s\n", self.address_string(),
                         self.log_date_time_string(), format % args)

    def do_HEAD(self):
        # NOTE(review): HEAD is answered exactly like GET, body included.
        return self.do_GET()

    def do_GET(self):
        """Dispatch a GET request to the item, changelog or banner handler."""
        # Raises for peers without a usable ED25519 certificate; the id itself
        # is not needed to serve files or previews.
        get_service_id(self.connection)
        parts = self.path.split('/')
        if len(parts) == 3 and parts[1] in ('get', 'preview'):
            item_id = parts[2]
            if len(item_id) == 32 and item_id.isalnum():
                self._send_item(item_id, parts[1] == 'preview')
                return
        if len(parts) == 2 and parts[1] == 'log':
            self._changelog()
            return
        self.send_response(200, 'OK')
        self.send_header('Content-type', 'text/plain')
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        self.end_headers()
        self.wfile.write('Open Media Library\n'.encode())

    @staticmethod
    def _file_mimetype(path):
        """Return the Content-Type for an item file, chosen by its extension."""
        return {
            'epub': 'application/epub+zip',
            'pdf': 'application/pdf',
            'txt': 'text/plain',
        }.get(path.split('.')[-1], 'application/octet-stream')

    def _send_item(self, item_id, preview):
        """Send the preview image or the file of item ``item_id`` (404 if unavailable)."""
        import item.models
        path = None
        content = None
        with db.session():
            if preview:
                from item.icons import get_icon_sync
                try:
                    content = get_icon_sync(item_id, 'preview', 512)
                except Exception:
                    logger.debug('failed to load preview for %s', item_id, exc_info=True)
                    content = None
                if content:
                    self.send_response(200, 'ok')
                    mimetype = 'image/jpg'
                else:
                    self.send_response(404, 'Not Found')
                    content = b'404 - Not Found'
                    mimetype = 'text/plain'
            else:
                item_file = item.models.File.get(item_id)
                if item_file:
                    path = item_file.fullpath()
                    mimetype = self._file_mimetype(path)
                    self.send_response(200, 'OK')
                else:
                    self.send_response(404, 'Not Found')
                    content = b'404 - Not Found'
                    mimetype = 'text/plain'
        self.send_header('Content-Type', mimetype)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        if mimetype == 'text/plain' and path:
            # Text files are read fully so they can be gzip-compressed.
            with open(path, 'rb') as f:
                content = self.gzip_data(f.read())
            content_length = len(content)
        elif path:
            # Other files are streamed from disk.
            content = None
            content_length = os.path.getsize(path)
        elif content:
            content_length = len(content)
        else:
            content_length = 0
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        if content:
            self.write_with_limit(content, content_length)
        elif path:
            self.write_file_with_limit(path, content_length)

    def _denied(self):
        self.send_response(403, 'denied')
        self.end_headers()

    @staticmethod
    def _parse_range(request_range, size):
        """Parse a ``Range: bytes=start-[end]`` header value.

        Returns ``(start, end)``. ``end`` defaults to ``size - 1`` and is
        clamped to it, so the advertised length never exceeds the data
        available. Returns ``None`` when the value cannot be parsed; this
        includes suffix ranges such as ``bytes=-500`` and multi-range requests.
        """
        bounds = request_range.split('=')[-1].split('-')
        try:
            start = int(bounds[0])
            end = int(bounds[1]) if bounds[1] else size - 1
        except (ValueError, IndexError):
            return None
        return start, min(end, size - 1)

    def _changelog(self):
        """Serve the changelog to a peered user; other users get 403.

        If we sent the peering request (``pending == 'sent'``), the peer is
        treated as having accepted: ``update_peering(True)`` is called and a
        ``peering.accept`` event is fired before serving.
        """
        user_id = get_service_id(self.connection)
        with db.session():
            u = user.models.User.get(user_id)
            if not u:
                return self._denied()
            if not u.peered and u.pending == 'sent':
                u.update_peering(True)
                state.nodes.queue('add', u.id, True)
                trigger_event('peering.accept', u.json())
            if u.pending:
                logger.debug('ignore request from pending peer[%s] changelog (pending sate: %s)', user_id, u.pending)
                return self._denied()
            if not u.peered:
                return self._denied()
        path = changelog_path()
        content_length = changelog_size()
        with open(path, 'rb') as log:
            request_range = self.headers.get('Range', '')
            byte_range = self._parse_range(request_range, content_length) if request_range else None
            if request_range and byte_range is None:
                # An unparsable Range header may be ignored (RFC 7233).
                logger.debug('ignoring invalid Range header: %s', request_range)
            if byte_range is None:
                self.send_response(200, 'OK')
            else:
                start, end = byte_range
                # A range starting exactly at the end of the log yields an
                # empty 206 response.
                length = 0 if start == content_length else end - start + 1
                if length < 0:
                    # Unsatisfiable range: fall back to sending the whole log.
                    content_length = os.path.getsize(path)
                    self.send_response(200, 'OK')
                else:
                    content_length = length
                    log.seek(start)
                    self.send_response(206, 'OK')
            self.send_header('Content-type', 'text/json')
            self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
            self.send_header('Content-Length', str(content_length))
            self.end_headers()
            self.write_fd_with_limit(log, content_length)

    def gzip_data(self, data):
        """Gzip ``data`` when the client accepts gzip, else return it unchanged.

        When compressing, this also sends the ``Content-Encoding: gzip``
        header, so it must be called between ``send_response`` and
        ``end_headers``.
        """
        # A missing Accept-Encoding header means no compression.
        encoding = self.headers.get('Accept-Encoding') or ''
        if 'gzip' not in encoding:
            return data
        self.send_header('Content-Encoding', 'gzip')
        return gzip.compress(data)

    def gunzip_data(self, data):
        """Return the decompressed contents of the gzip payload ``data``."""
        # NOTE(review): no limit on the decompressed size of peer-supplied data.
        return gzip.decompress(data)

    def do_POST(self):
        """Handle a node API call.

        The body is a JSON array ``[action, args]``, optionally gzip-encoded
        (``Content-Encoding: gzip``). ``args`` is passed positionally to the
        API function via ``api_call``.

            requestPeering  username message
            acceptPeering   username message
            rejectPeering   message
            removePeering   message

            ping            responds {"status": "ok"}

        A peer speaking a different ``X-Node-Protocol`` receives
        ``settings.release`` instead of an API response.
        """
        user_id = get_service_id(self.connection)

        try:
            content_len = int(self.headers.get('content-length', 0))
            if content_len < 0:
                # read(-1) would block until the peer closes the connection.
                raise ValueError('negative Content-Length')
            data = self.rfile.read(content_len)
            if self.headers.get('Content-Encoding') == 'gzip':
                data = self.gunzip_data(data)
        except Exception:
            logger.debug('invalid request', exc_info=True)
            self.write_response((500, 'invalid request'), {})
            return

        peer_protocol = self.headers.get('X-Node-Protocol', '')
        if peer_protocol > settings.NODE_PROTOCOL:
            state.update_required = True
        if peer_protocol != settings.NODE_PROTOCOL:
            logger.debug('protocol missmatch %s vs %s',
                         peer_protocol, settings.NODE_PROTOCOL)
            logger.debug('headers %s', self.headers)
            self.write_response((200, 'OK'), settings.release)
            return

        try:
            action, args = json.loads(data.decode('utf-8'))
            # 'api_' + action and *args need a string and a list.
            if not isinstance(action, str) or not isinstance(args, list):
                raise ValueError('expected [action, args]')
        except Exception:
            logger.debug('invalid data: %s', data, exc_info=True)
            self.write_response((500, 'invalid request'), {'status': 'invalid request'})
            return
        logger.debug('%s requests %s%s', user_id, action, args)
        if action == 'ping':
            self.write_response((200, 'OK'), {'status': 'ok'})
            return
        try:
            content = api_call(action, user_id, args)
        except Exception:
            # Unknown actions (no nodeapi.api_<action>) and API failures:
            # answer instead of dropping the connection without a response.
            logger.error('%s request %s failed', user_id, action, exc_info=True)
            self.write_response((500, 'invalid request'), {'status': 'invalid request'})
            return
        if content is None:
            logger.debug('PEER %s IS UNKNOWN SEND 403', user_id)
            self.write_response((403, 'UNKNOWN USER'), {})
            return
        self.write_response((200, 'OK'), content)

    def write_response(self, response_status, content):
        """Send ``content`` as a JSON response, gzipped if the client accepts it.

        ``response_status`` is a ``(code, message)`` tuple for ``send_response``.
        """
        self.send_response(*response_status)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        self.send_header('Content-Type', 'application/json')
        content = json.dumps(content, ensure_ascii=False).encode('utf-8')
        content = self.gzip_data(content)
        content_length = len(content)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        self.wfile.write(content)

    def chunk_size(self, content_length):
        # Uploads are paced in chunks of at most 16 KiB.
        return min(16*1024, content_length)

    def _throttle(self, size):
        """Wait until ``state.bandwidth`` (if configured) allows ``size`` more
        bytes to be uploaded, or until the server is stopping."""
        if state.bandwidth:
            while not state.bandwidth.upload(size) and self.server._running:
                time.sleep(0.1)

    def write_with_limit(self, content, content_length):
        """Write ``content`` (``content_length`` bytes) in rate-limited chunks."""
        chunk_size = self.chunk_size(content_length)
        position = 0
        while position < content_length:
            chunk = content[position:position + chunk_size]
            # Charge the bytes actually sent, including a short last chunk.
            self._throttle(len(chunk))
            self.wfile.write(chunk)
            position += chunk_size

    def write_fd_with_limit(self, f, content_length):
        """Copy up to ``content_length`` bytes from the open file ``f`` to the
        client in rate-limited chunks, stopping early at end of file."""
        chunk_size = self.chunk_size(content_length)
        remaining = content_length
        while remaining > 0:
            data = f.read(min(chunk_size, remaining))
            if not data:
                break
            # Every chunk, including the first, is charged to the limiter.
            self._throttle(len(data))
            self.wfile.write(data)
            remaining -= len(data)

    def write_file_with_limit(self, path, content_length):
        """Stream up to ``content_length`` bytes of the file at ``path``."""
        with open(path, 'rb') as f:
            self.write_fd_with_limit(f, content_length)

class Server(Thread):
    """Daemon thread that owns and runs the node's ``NodeServer``.

    Constructing a ``Server`` creates the server on the address from
    ``settings.server`` and immediately starts serving in the background.
    """
    http_server = None

    def __init__(self):
        super().__init__()
        host = settings.server['node_address']
        port = settings.server['node_port']
        self.http_server = NodeServer((host, port), Handler)
        # Do not keep the process alive just for this thread.
        self.daemon = True
        self.start()

    def run(self):
        """Thread body: serve requests until ``stop()`` shuts the server down."""
        self.http_server.serve_forever()

    def stop(self):
        """Shut the server down, close its socket and wait for this thread."""
        server = self.http_server
        if server:
            # Lets request handlers waiting on the upload limiter give up.
            server._running = False
            server.shutdown()
            server.socket.close()
        return self.join()

def start():
    """Start the node server in a background thread and return that thread."""
    server_thread = Server()
    return server_thread