# -*- coding: utf-8 -*-
from socketserver import ThreadingMixIn
from threading import Thread
import base64
import gzip
import hashlib
import http.server
import io
import json
import os
import socket
import socketserver
import time

from Crypto.PublicKey import RSA
from Crypto.Util.asn1 import DerSequence
from OpenSSL.crypto import dump_privatekey, FILETYPE_ASN1
from OpenSSL.SSL import (
    Context, Connection, TLSv1_2_METHOD,
    VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE
)

import db
import settings
import state
import user
from changelog import changelog_size, changelog_path
from websocket import trigger_event

from . import nodeapi
from .sslsocket import fileobject

import logging
# module-level logger, named after this module
logger = logging.getLogger(__name__)


def get_service_id(key):
    '''
    Return the service id that identifies the owner of the RSA *key*.

    The service id is the lower-case base32 encoding of the first half
    (10 bytes) of the sha1 digest of the DER-encoded RSA public key.
    '''
    # Dump the key as DER and decode its outer ASN.1 SEQUENCE; elements 1 and 2
    # are handed to RSA.construct as (n, e), i.e. modulus and public exponent.
    sequence = DerSequence()
    sequence.decode(dump_privatekey(FILETYPE_ASN1, key))
    modulus, exponent = sequence._seq[1], sequence._seq[2]
    exported = RSA.construct((modulus, exponent)).exportKey('DER')
    # The first 22 bytes of the exported DER are dropped before hashing.
    # NOTE(review): presumably this strips the SubjectPublicKeyInfo header,
    # leaving the bare RSAPublicKey; a 22-byte header only matches certain key
    # sizes (e.g. 1024-bit) — confirm the key size used by nodes.
    digest = hashlib.sha1(exported[22:]).digest()
    return base64.b32encode(digest[:10]).lower().decode()

class TLSTCPServer(socketserver.TCPServer):
    '''
    TCP server whose listening socket is a pyOpenSSL TLS connection.

    Clients must present a certificate (VERIFY_FAIL_IF_NO_PEER_CERT), but any
    certificate passes the handshake; request handlers identify the peer from
    the certificate's public key.
    '''

    def _accept(self, connection, x509, errnum, errdepth, ok):
        # Verification callback: accept every certificate here, the client id
        # derived from it is validated per request.
        return True

    def __init__(self, server_address, HandlerClass, bind_and_activate=True):
        # BaseServer.__init__ only records address and handler class.
        # TCPServer.__init__ would additionally create (and by default bind and
        # listen on) a plain socket that is immediately replaced below, leaking
        # it and binding the address twice.
        socketserver.BaseServer.__init__(self, server_address, HandlerClass)
        ctx = Context(TLSv1_2_METHOD)
        ctx.use_privatekey_file(settings.ssl_key_path)
        ctx.use_certificate_file(settings.ssl_cert_path)
        # only allow clients with cert:
        ctx.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE | VERIFY_FAIL_IF_NO_PEER_CERT, self._accept)
        self.socket = Connection(ctx, socket.socket(self.address_family, self.socket_type))
        if bind_and_activate:
            # same bind/activate sequence as TCPServer.__init__, including
            # closing the socket if either step fails
            try:
                self.server_bind()
                self.server_activate()
            except BaseException:
                self.server_close()
                raise

    def shutdown_request(self, request):
        '''Send the TLS shutdown (best effort), then close the connection.'''
        try:
            # OpenSSL.SSL.Connection.shutdown() takes no "how" argument,
            # unlike socket.shutdown().
            request.shutdown()
        except Exception:
            # The peer may already be gone; the connection is closed anyway.
            pass
        self.close_request(request)

class NodeServer(ThreadingMixIn, TLSTCPServer):
    '''TLS node server that handles each request in its own thread.'''

    # set SO_REUSEADDR on the listening socket before binding
    allow_reuse_address = True
    # cleared by Server.stop(); while set, throttled Handler writes keep
    # waiting for upload bandwidth
    _running = True


def api_call(action, user_id, args):
    '''
    Dispatch the node API *action* requested by *user_id* to nodeapi.

    Peering-management actions are dispatched for anyone, all other actions
    only for peered users, as ``nodeapi.api_<action>(user_id, *args)``.

    Returns the dispatched function's result; ``{}`` if the request was
    ignored because the user has a pending peering; ``None`` if the user is
    neither allowed nor pending.
    '''
    peering_actions = (
        'requestPeering', 'acceptPeering', 'rejectPeering',
        'removePeering', 'cancelPeering'
    )
    with db.session():
        peer = user.models.User.get(user_id)
        allowed = action in peering_actions or (peer and peer.peered)
        if allowed:
            result = getattr(nodeapi, 'api_' + action)(user_id, *args)
        elif peer and peer.pending:
            logger.debug('ignore request from pending peer[%s] %s (%s) (pending state: %s)',
                         user_id, action, args, peer.pending)
            result = {}
        else:
            result = None
    return result

class Handler(http.server.SimpleHTTPRequestHandler):
    '''
    HTTP request handler for the node-to-node API (served over TLS).

        GET /get/<id>      file of item.models.File <id>
        GET /preview/<id>  preview icon of item <id>
        GET /log           changelog, peered users only, single Range supported
        GET <other>        plain-text banner
        HEAD               same status and headers as GET, without the body
        POST               JSON API call, see do_POST

    <id> must be 32 alphanumeric characters; anything else gets the banner.
    '''

    # lower-cased file extension -> Content-Type for /get/<id>
    _MIMETYPES = {
        'epub': 'application/epub+zip',
        'pdf': 'application/pdf',
        'txt': 'text/plain',
    }
    # Content-Type for files with any other extension
    _DEFAULT_MIMETYPE = 'application/octet-stream'

    def setup(self):
        # Replaces StreamRequestHandler.setup: wrap the accepted TLS connection
        # in the .sslsocket file objects instead of socket.makefile() ones.
        self.connection = self.request
        self.rfile = fileobject(self.connection, 'rb', self.rbufsize)
        self.wfile = fileobject(self.connection, 'wb', self.wbufsize)

    def version_string(self):
        # value of the Server response header
        return settings.USER_AGENT

    def log_message(self, format, *args):
        # Send http.server's access log to the module logger (only when
        # settings.DEBUG_HTTP is set) instead of writing it to stderr.
        if settings.DEBUG_HTTP:
            logger.debug("%s - - [%s] %s\n", self.address_string(),
                         self.log_date_time_string(), format % args)

    def _send_body(self):
        '''Return False for HEAD requests: their response must not carry a body.'''
        return self.command != 'HEAD'

    def do_HEAD(self):
        # do_GET produces the same status and headers and skips the body.
        return self.do_GET()

    def do_GET(self):
        '''Dispatch GET/HEAD requests by path (see the class docstring).'''
        parts = self.path.split('/')
        if len(parts) == 3 and parts[1] in ('get', 'preview'):
            item_id = parts[2]
            if len(item_id) == 32 and item_id.isalnum():
                self._send_item(item_id, preview=parts[1] == 'preview')
                return
        elif len(parts) == 2 and parts[1] == 'log':
            self._changelog()
            return
        self.send_response(200, 'OK')
        self.send_header('Content-type', 'text/plain')
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        self.end_headers()
        if self._send_body():
            self.wfile.write('Open Media Library\n'.encode())

    def _send_item(self, item_id, preview):
        '''Send the preview icon (preview=True) or the file of *item_id*, else a 404.'''
        import item.models
        content = None
        path = None
        with db.session():
            if preview:
                from item.icons import get_icon_sync
                try:
                    content = get_icon_sync(item_id, 'preview', 512)
                except Exception:
                    logger.debug('failed to load preview for %s', item_id, exc_info=True)
                    content = None
            else:
                item_file = item.models.File.get(item_id)
                if item_file:
                    path = item_file.fullpath()

        if preview and content:
            self.send_response(200, 'ok')
            # NOTE(review): the registered type is image/jpeg; kept as-is in
            # case peers compare this value — confirm before changing.
            mimetype = 'image/jpg'
        elif path and os.path.isfile(path):
            self.send_response(200, 'OK')
            extension = path.split('.')[-1].lower()
            mimetype = self._MIMETYPES.get(extension, self._DEFAULT_MIMETYPE)
        else:
            # unknown id, no preview, or the file is missing on disk
            self.send_response(404, 'Not Found')
            content = b'404 - Not Found'
            path = None
            mimetype = 'text/plain'
        self.send_header('Content-Type', mimetype)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        if path and mimetype == 'text/plain':
            # Text files are read into memory so they can be gzipped
            # (gzip_data compresses only if the client accepts gzip).
            with open(path, 'rb') as f:
                content = self.gzip_data(f.read())
            path = None
        content_length = os.path.getsize(path) if path else len(content)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        if not self._send_body():
            return
        if path:
            # stream the file from disk in throttled chunks
            self.write_file_with_limit(path, content_length)
        elif content:
            self.write_with_limit(content, content_length)

    def _denied(self):
        self.send_response(403, 'denied')
        self.end_headers()

    @staticmethod
    def _parse_range(request_range, size):
        '''
        Parse a single-range ``bytes=<start>-[<end>]`` Range header value.

        Returns (start, length) to serve from a file of *size* bytes; length
        is 0 when start == size. An end past the last byte is clamped to it.
        Returns None when the value is malformed (including suffix ranges
        "bytes=-N" and multiple ranges) or describes a negative length (e.g.
        start past the end); the caller then sends the whole file.
        '''
        first, sep, last = request_range.split('=')[-1].partition('-')
        if not sep:
            return None
        try:
            start = int(first)
            end = int(last) if last else size - 1
        except ValueError:
            return None
        if start == size:
            return start, 0
        length = min(end, size - 1) - start + 1
        if length < 0:
            return None
        return start, length

    def _changelog(self):
        '''
        Send the changelog to the peer identified by its client certificate.

        Unknown users, users with a pending peering and unpeered users get 403.
        If the user is not peered and its pending state is 'sent', the peering
        is completed first (update_peering(True), queued 'add' on state.nodes,
        'peering.accept' event). A single byte Range is answered with 206.
        '''
        x509 = self.connection.get_peer_certificate()
        user_id = get_service_id(x509.get_pubkey()) if x509 else None
        with db.session():
            u = user.models.User.get(user_id)
            if not u:
                return self._denied()
            if not u.peered and u.pending == 'sent':
                u.update_peering(True)
                state.nodes.queue('add', u.id, True)
                trigger_event('peering.accept', u.json())
            if u.pending:
                logger.debug('ignore request from pending peer[%s] changelog (pending sate: %s)', user_id, u.pending)
                return self._denied()
            if not u.peered:
                return self._denied()
        path = changelog_path()
        size = changelog_size()
        content_length = size
        with open(path, 'rb') as log:
            request_range = self.headers.get('Range', '')
            byte_range = self._parse_range(request_range, size) if request_range else None
            if byte_range is not None:
                start, content_length = byte_range
                log.seek(start)
                self.send_response(206, 'OK')
                if content_length:
                    self.send_header('Content-Range', 'bytes %d-%d/%d' % (
                        start, start + content_length - 1, size))
            else:
                if request_range:
                    # unusable Range: fall back to sending the whole log
                    content_length = os.path.getsize(path)
                self.send_response(200, 'OK')
            self.send_header('Content-type', 'text/json')
            self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
            self.send_header('Content-Length', str(content_length))
            self.end_headers()
            if self._send_body():
                self.write_fd_with_limit(log, content_length)

    def gzip_data(self, data):
        '''
        Return *data* gzip-compressed if the client's Accept-Encoding mentions
        gzip (also sending the Content-Encoding header), else *data* as is.

        Must be called before end_headers().
        '''
        accept_encoding = self.headers.get('Accept-Encoding', '') or ''
        if 'gzip' not in accept_encoding.lower():
            return data
        self.send_header('Content-Encoding', 'gzip')
        return gzip.compress(data)

    def gunzip_data(self, data):
        '''Return the decompressed content of the gzip stream *data*.'''
        return gzip.decompress(data)

    def do_POST(self):
        '''
        Handle a node API call.

        The body is the JSON encoding of ``[action, args]``, optionally
        gzip-compressed (Content-Encoding: gzip). The caller is identified by
        the service id of its client certificate.

            requestPeering  username message
            acceptPeering   username message
            rejectPeering   message
            removePeering   message
            cancelPeering   (arguments: see nodeapi)
            ping            returns {'status': 'ok'}

        Peered users may call any other ``nodeapi.api_<action>``.
        When X-Node-Protocol differs from ours, settings.release is returned
        instead. Users that are not allowed get 403; unreadable requests and
        failing actions get 500.
        '''
        x509 = self.connection.get_peer_certificate()
        user_id = get_service_id(x509.get_pubkey()) if x509 else None

        try:
            content_len = int(self.headers.get('content-length', 0))
            if content_len < 0:
                raise ValueError('negative Content-Length: %d' % content_len)
            data = self.rfile.read(content_len)
            if self.headers.get('Content-Encoding') == 'gzip':
                data = self.gunzip_data(data)
        except Exception:
            logger.debug('invalid request', exc_info=True)
            self.write_response((500, 'invalid request'), {})
            return

        response_status = (200, 'OK')
        peer_protocol = self.headers.get('X-Node-Protocol', '')
        # NOTE(review): plain string comparison; assumes protocol versions
        # sort lexicographically — confirm.
        if peer_protocol > settings.NODE_PROTOCOL:
            state.update_required = True
        if peer_protocol != settings.NODE_PROTOCOL:
            logger.debug('protocol missmatch %s vs %s',
                         peer_protocol, settings.NODE_PROTOCOL)
            logger.debug('headers %s', self.headers)
            self.write_response(response_status, settings.release)
            return

        try:
            action, args = json.loads(data.decode('utf-8'))
        except Exception:
            logger.debug('invalid data: %s', data, exc_info=True)
            self.write_response((500, 'invalid request'), {'status': 'invalid request'})
            return
        logger.debug('%s requests %s%s', user_id, action, args)
        if action == 'ping':
            content = {'status': 'ok'}
        else:
            try:
                content = api_call(action, user_id, args)
            except Exception:
                # e.g. unknown action (no nodeapi.api_<action>) or wrong args:
                # answer with an error instead of dropping the connection
                logger.warning('%s: api call %s failed', user_id, action, exc_info=True)
                self.write_response((500, 'invalid request'), {'status': 'invalid request'})
                return
            if content is None:
                logger.debug('PEER %s IS UNKNOWN SEND 403', user_id)
                response_status = (403, 'UNKNOWN USER')
                content = {}
        self.write_response(response_status, content)

    def write_response(self, response_status, content):
        '''Send *content* as a (possibly gzipped) JSON response with *response_status*.'''
        self.send_response(*response_status)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        self.send_header('Content-Type', 'application/json')
        content = json.dumps(content, ensure_ascii=False).encode('utf-8')
        content = self.gzip_data(content)
        content_length = len(content)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        self.wfile.write(content)

    def chunk_size(self, content_length):
        # write in chunks of at most 16 KiB
        return min(16*1024, content_length)

    def _wait_for_bandwidth(self, size):
        '''If state.bandwidth is set, poll until it grants *size* bytes or the server stops.'''
        if state.bandwidth:
            while not state.bandwidth.upload(size) and self.server._running:
                time.sleep(0.1)

    def write_with_limit(self, content, content_length):
        '''Write the first *content_length* bytes of *content* in throttled chunks.'''
        position = 0
        while position < content_length:
            size = self.chunk_size(content_length - position)
            self._wait_for_bandwidth(size)
            self.wfile.write(content[position:position + size])
            position += size

    def write_fd_with_limit(self, f, content_length):
        '''Copy up to *content_length* bytes from file *f* in throttled chunks.'''
        remaining = content_length
        while remaining > 0:
            size = self.chunk_size(remaining)
            self._wait_for_bandwidth(size)
            data = f.read(size)
            if not data:
                break
            self.wfile.write(data)
            # count what was actually read, not what was requested
            remaining -= len(data)

    def write_file_with_limit(self, path, content_length):
        with open(path, 'rb') as f:
            self.write_fd_with_limit(f, content_length)

class Server(Thread):
    '''
    Daemon thread running the node's NodeServer.

    The thread is started from the constructor.
    '''
    http_server = None

    def __init__(self):
        Thread.__init__(self)
        host = settings.server['node_address']
        port = settings.server['node_port']
        self.http_server = NodeServer((host, port), Handler)
        self.daemon = True
        self.start()

    def run(self):
        # serves requests until stop() calls shutdown()
        self.http_server.serve_forever()

    def stop(self):
        '''Stop serving, close the listening socket and wait for this thread.'''
        server = self.http_server
        if server:
            # lets throttled uploads stop waiting for bandwidth
            server._running = False
            server.shutdown()
            server.socket.close()
        return Thread.join(self)

def start():
    '''Create the node server thread (it starts itself) and return it.'''
    server = Server()
    return server