openmedialibrary/oml/node/server.py

379 lines
14 KiB
Python

# -*- coding: utf-8 -*-
# vi:si:et:sw=4:sts=4:ts=4
from socketserver import ThreadingMixIn
from threading import Thread
import base64
import gzip
import hashlib
import http.server
import io
import json
import os
import socket
import socketserver
import time
from Crypto.PublicKey import RSA
from Crypto.Util.asn1 import DerSequence
from OpenSSL.crypto import dump_privatekey, FILETYPE_ASN1
from OpenSSL.SSL import (
Context, Connection, TLSv1_2_METHOD,
VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE
)
import db
import settings
import state
import user
from changelog import changelog_size, changelog_path
from websocket import trigger_event
from . import nodeapi
from .sslsocket import fileobject
import logging
logger = logging.getLogger(__name__)
def get_service_id(key):
    '''
    Return the service id that belongs to an RSA key.

    The ASN.1 dump of the key is parsed, its second and third number are used
    to construct the public key. The service id is the first half (10 bytes)
    of the sha1 digest of the DER export of that public key, with the first
    22 bytes removed, encoded as lowercase base32.
    '''
    sequence = DerSequence()
    sequence.decode(dump_privatekey(FILETYPE_ASN1, key))
    numbers = (sequence._seq[1], sequence._seq[2])
    # NOTE(review): the 22 bytes dropped here are presumably the DER header
    # in front of the actual key data - confirm
    der = RSA.construct(numbers).exportKey('DER')
    digest = hashlib.sha1(der[22:]).digest()
    half = digest[:10]
    return base64.b32encode(half).lower().decode()
class TLSTCPServer(socketserver.TCPServer):
    '''
    TCP server that accepts TLS connections.

    Clients have to present a certificate. The certificate itself is not
    checked during the handshake (see _accept), the request handler
    identifies the peer by the certificate it presented.
    '''

    def _accept(self, connection, x509, errnum, errdepth, ok):
        # verify callback: accept every certificate,
        # client_id is validated in request
        return True

    def __init__(self, server_address, HandlerClass, bind_and_activate=True):
        '''
        server_address     (host, port) to listen on
        HandlerClass       request handler class
        bind_and_activate  bind and listen right away (default)
        '''
        # The base class must not bind: it would bind and listen on the plain
        # socket, that socket is replaced by the TLS connection below and
        # binding happens on the TLS connection instead.
        socketserver.TCPServer.__init__(self, server_address, HandlerClass,
                                        bind_and_activate=False)
        try:
            ctx = Context(TLSv1_2_METHOD)
            ctx.use_privatekey_file(settings.ssl_key_path)
            ctx.use_certificate_file(settings.ssl_cert_path)
            # only allow clients with cert:
            ctx.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE | VERIFY_FAIL_IF_NO_PEER_CERT, self._accept)
            # wrap the socket created by the base class, it has the
            # configured address_family and socket_type
            self.socket = Connection(ctx, self.socket)
            if bind_and_activate:
                self.server_bind()
                self.server_activate()
        except BaseException:
            # do not leak the socket if setup fails, same as the base class
            self.server_close()
            raise

    def shutdown_request(self, request):
        try:
            request.shutdown()
        except Exception:
            # best effort, the connection might already be gone
            pass
class NodeServer(ThreadingMixIn, TLSTCPServer):
    '''
    TLS server that handles every request in its own thread.
    '''
    # picked up by socketserver.TCPServer.server_bind
    allow_reuse_address = True
    # set to False once the server is stopped, handlers that wait for
    # upload bandwidth check this flag to stop waiting
    _running = True
def api_call(action, user_id, args):
    '''
    Run the nodeapi function api_<action> for the user with the given id.

    Peering actions are available to everybody, all other actions only to
    peered users.

    Returns the result of the api function,
    {} if the request of a pending peer was ignored,
    None if the user is neither peered nor pending.
    '''
    peering_actions = (
        'requestPeering', 'acceptPeering', 'rejectPeering',
        'removePeering', 'cancelPeering'
    )
    with db.session():
        peer = user.models.User.get(user_id)
        if action in peering_actions or (peer and peer.peered):
            handler = getattr(nodeapi, 'api_' + action)
            return handler(user_id, *args)
        if peer and peer.pending:
            logger.debug('ignore request from pending peer[%s] %s (%s) (pending state: %s)',
                         user_id, action, args, peer.pending)
            return {}
        return None
class Handler(http.server.SimpleHTTPRequestHandler):
    '''
    Handles the requests of other nodes.

        GET  /get/<id>      send the file with the given id
        GET  /preview/<id>  send the preview image for the given id
        GET  /log           send the changelog (peered users only),
                            a Range header selects a part of it
        GET  anything else  send a short plain text greeting
        POST                api call, body is the json encoded
                            list [action, args]
    '''

    def setup(self):
        # read from and write to the connection through the file objects
        # provided by .sslsocket
        self.connection = self.request
        self.rfile = fileobject(self.connection, 'rb', self.rbufsize)
        self.wfile = fileobject(self.connection, 'wb', self.wbufsize)

    def version_string(self):
        return settings.USER_AGENT

    def log_message(self, format, *args):
        # requests are only logged if http debugging is enabled
        if settings.DEBUG_HTTP:
            logger.debug("%s - - [%s] %s\n", self.address_string(),
                         self.log_date_time_string(), format % args)

    def do_HEAD(self):
        # NOTE(review): this sends the same response as GET, including the body
        return self.do_GET()

    def do_GET(self):
        #x509 = self.connection.get_peer_certificate()
        #user_id = get_service_id(x509.get_pubkey()) if x509 else None
        import item.models
        parts = self.path.split('/')
        item_id = None
        preview = False
        if len(parts) == 3 and parts[1] in ('get', 'preview'):
            item_id = parts[2]
            preview = parts[1] == 'preview'
        if item_id and len(item_id) == 32 and item_id.isalnum():
            with db.session():
                if preview:
                    self._send_preview(item_id)
                else:
                    self._send_file(item.models.File.get(item_id))
        elif len(parts) == 2 and parts[1] == 'log':
            self._changelog()
        else:
            self.send_response(200, 'OK')
            self.send_header('Content-type', 'text/plain')
            self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
            self.end_headers()
            self.wfile.write('Open Media Library\n'.encode())

    def _send_preview(self, item_id):
        '''
        Send the 512 pixel preview icon for item_id, 404 if there is none.
        '''
        from item.icons import get_icon_sync
        try:
            data = get_icon_sync(item_id, 'preview', 512)
        except Exception:
            data = None
        if data:
            self.send_response(200, 'ok')
            self.send_header('Content-type', 'image/jpg')
        else:
            self.send_response(404, 'Not Found')
            self.send_header('Content-type', 'text/plain')
            data = b'404 - Not Found'
        content_length = len(data)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        self.write_with_limit(data, content_length)

    def _send_file(self, file_):
        '''
        Send the file that belongs to file_ (a File or None),
        404 if there is no such file on disk.
        '''
        path = file_.fullpath() if file_ else None
        # check before the status line is sent, there is no way
        # to turn a 200 into an error later
        if not path or not os.path.isfile(path):
            self._not_found()
            return
        extension = path.split('.')[-1].lower()
        # unknown extensions are sent as generic binary data
        mimetype = {
            'epub': 'application/epub+zip',
            'pdf': 'application/pdf',
            'txt': 'text/plain',
        }.get(extension, 'application/octet-stream')
        self.send_response(200, 'OK')
        self.send_header('Content-Type', mimetype)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        if mimetype == 'text/plain':
            # text is compressed if the client accepts that,
            # gzip_data adds the Content-Encoding header
            with open(path, 'rb') as f:
                content = f.read()
            content = self.gzip_data(content)
            content_length = len(content)
        else:
            content = None
            content_length = os.path.getsize(path)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        if content is not None:
            self.write_with_limit(content, content_length)
        else:
            self.write_file_with_limit(path, content_length)

    def _not_found(self):
        '''
        Send a plain text 404 response.
        '''
        body = b'404 - Not Found'
        self.send_response(404, 'Not Found')
        self.send_header('Content-type', 'text/plain')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _denied(self):
        self.send_response(403, 'denied')
        self.end_headers()

    @staticmethod
    def _parse_range(request_range, size):
        '''
        Parse a Range header like "bytes=<start>-<end>" or "bytes=<start>-"
        for a resource of size bytes.

        Returns (start, length), the end is limited to the last byte of the
        resource. Returns None if the header can not be parsed or describes
        a negative length.
        '''
        try:
            r = request_range.split('=')[-1].split('-')
            start = int(r[0])
            end = int(r[1]) if r[1] else (size - 1)
        except (ValueError, IndexError):
            return None
        if start < 0:
            return None
        if start == size:
            # everything was sent already
            return start, 0
        # never announce more bytes than there are
        length = min(end, size - 1) - start + 1
        if length < 0:
            return None
        return start, length

    def _changelog(self):
        '''
        Send the changelog to a peered user, all others get a 403.

        A request from a user with a sent peering request that is still
        pending accepts the peering.
        '''
        x509 = self.connection.get_peer_certificate()
        user_id = get_service_id(x509.get_pubkey()) if x509 else None
        with db.session():
            u = user.models.User.get(user_id)
            if not u:
                return self._denied()
            if not u.peered and u.pending == 'sent':
                u.update_peering(True)
                state.nodes.queue('add', u.id, True)
                trigger_event('peering.accept', u.json())
            if u.pending:
                logger.debug('ignore request from pending peer[%s] changelog (pending sate: %s)', user_id, u.pending)
                return self._denied()
            if not u.peered:
                return self._denied()
            path = changelog_path()
            content_length = changelog_size()
            with open(path, 'rb') as log:
                request_range = self.headers.get('Range', '')
                byte_range = None
                if request_range:
                    byte_range = self._parse_range(request_range, content_length)
                if byte_range is not None:
                    start, content_length = byte_range
                    log.seek(start)
                    self.send_response(206, 'OK')
                else:
                    if request_range:
                        # unusable range, send the whole file
                        logger.debug('ignore invalid range: %s', request_range)
                        content_length = os.path.getsize(path)
                    self.send_response(200, 'OK')
                self.send_header('Content-type', 'text/json')
                self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
                self.send_header('Content-Length', str(content_length))
                self.end_headers()
                self.write_fd_with_limit(log, content_length)

    def gzip_data(self, data):
        '''
        Return data gzip compressed if the client accepts gzip, in that case
        the Content-Encoding header is added to the response, so this has to
        be called before end_headers. Otherwise data is returned unchanged.
        '''
        # the header is optional
        encoding = self.headers.get('Accept-Encoding') or ''
        if 'gzip' not in encoding:
            return data
        self.send_header('Content-Encoding', 'gzip')
        bytes_io = io.BytesIO()
        with gzip.GzipFile(fileobj=bytes_io, mode='wb') as gzip_file:
            gzip_file.write(data)
        return bytes_io.getvalue()

    def gunzip_data(self, data):
        '''
        Return the decompressed content of the gzip compressed data.
        '''
        with gzip.GzipFile(fileobj=io.BytesIO(data), mode='rb') as gzip_file:
            return gzip_file.read()

    def do_POST(self):
        '''
        API, the body is the json encoded list [action, args]

            requestPeering username message
            acceptPeering  username message
            rejectPeering  message
            removePeering  message

            ping           responds with status ok

        Requests with a different X-Node-Protocol header are not processed,
        they are answered with settings.release.
        '''
        x509 = self.connection.get_peer_certificate()
        user_id = get_service_id(x509.get_pubkey()) if x509 else None
        try:
            content_len = int(self.headers.get('content-length', 0))
            data = self.rfile.read(content_len)
            if self.headers.get('Content-Encoding') == 'gzip':
                data = self.gunzip_data(data)
        except Exception:
            logger.debug('invalid request', exc_info=True)
            self.write_response((500, 'invalid request'), {})
            return

        protocol = self.headers.get('X-Node-Protocol', '')
        # NOTE(review): this compares strings, it assumes that protocol
        # versions sort correctly as strings - confirm
        if protocol > settings.NODE_PROTOCOL:
            state.update_required = True
        if protocol != settings.NODE_PROTOCOL:
            logger.debug('protocol missmatch %s vs %s',
                         protocol, settings.NODE_PROTOCOL)
            logger.debug('headers %s', self.headers)
            self.write_response((200, 'OK'), settings.release)
            return

        try:
            action, args = json.loads(data.decode('utf-8'))
        except Exception:
            logger.debug('invalid data: %s', data, exc_info=True)
            self.write_response((500, 'invalid request'), {
                'status': 'invalid request'
            })
            return
        logger.debug('%s requests %s%s', user_id, action, args)
        response_status = (200, 'OK')
        if action == 'ping':
            content = {
                'status': 'ok'
            }
        else:
            content = api_call(action, user_id, args)
            if content is None:
                # neither peered nor pending
                logger.debug('PEER %s IS UNKNOWN SEND 403', user_id)
                response_status = (403, 'UNKNOWN USER')
                content = {}
        self.write_response(response_status, content)

    def write_response(self, response_status, content):
        '''
        Send content as json with the given (code, message) status,
        compressed if the client accepts gzip.
        '''
        self.send_response(*response_status)
        self.send_header('X-Node-Protocol', settings.NODE_PROTOCOL)
        self.send_header('Content-Type', 'application/json')
        content = json.dumps(content, ensure_ascii=False).encode('utf-8')
        content = self.gzip_data(content)
        content_length = len(content)
        self.send_header('Content-Length', str(content_length))
        self.end_headers()
        self.wfile.write(content)
        logger.debug('%s bytes response', content_length)

    def chunk_size(self, content_length):
        # write at most 16 KiB at once
        return min(16*1024, content_length)

    def _wait_for_bandwidth(self, size):
        '''
        If a bandwidth limit is active, wait until size bytes can be
        uploaded or the server is stopped.
        '''
        if state.bandwidth:
            while not state.bandwidth.upload(size) and self.server._running:
                time.sleep(0.1)

    def write_with_limit(self, content, content_length):
        '''
        Write the first content_length bytes of content in chunks,
        respecting the upload bandwidth limit.
        '''
        chunk_size = self.chunk_size(content_length)
        position = 0
        while position < content_length:
            data = content[position:min(position + chunk_size, content_length)]
            if not data:
                # content is shorter than content_length
                break
            self._wait_for_bandwidth(len(data))
            self.wfile.write(data)
            position += len(data)

    def write_fd_with_limit(self, f, content_length):
        '''
        Write up to content_length bytes read from the current position of
        the file object f in chunks, respecting the upload bandwidth limit.
        '''
        remaining = content_length
        while remaining > 0:
            chunk_size = self.chunk_size(remaining)
            # wait before every chunk, including the first one
            self._wait_for_bandwidth(chunk_size)
            data = f.read(chunk_size)
            if not data:
                break
            self.wfile.write(data)
            # count what was read, a read can return less than requested
            remaining -= len(data)

    def write_file_with_limit(self, path, content_length):
        with open(path, 'rb') as f:
            self.write_fd_with_limit(f, content_length)
class Server(Thread):
    '''
    Daemon thread that runs the node server.

    The server is created for the address and port from settings.server,
    the thread starts itself as soon as it is created.
    '''
    http_server = None

    def __init__(self):
        super().__init__()
        host = settings.server['node_address']
        port = settings.server['node_port']
        self.http_server = NodeServer((host, port), Handler)
        self.daemon = True
        self.start()

    def run(self):
        # blocks until shutdown is called in stop
        self.http_server.serve_forever()

    def stop(self):
        '''
        Shut down the server, close its socket and wait for the thread.
        '''
        server = self.http_server
        if server:
            # handlers waiting for bandwidth check this flag
            server._running = False
            server.shutdown()
            server.socket.close()
        return super().join()
def start():
    '''
    Create the node server thread and return it,
    the thread starts itself on creation.
    '''
    server = Server()
    return server