cleanup tls fingerprint check

This commit is contained in:
j 2014-09-09 16:29:31 +02:00
parent 0956bd4966
commit 6e2c91fb4c

View file

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# vi:si:et:sw=4:sts=4:ts=4
import http.client import http.client
import socket
import urllib.request, urllib.error, urllib.parse import urllib.request, urllib.error, urllib.parse
import ssl
import hashlib import hashlib
import logging import logging
logger = logging.getLogger('oml.ssl_request') logger = logging.getLogger('oml.ssl_request')
@ -9,75 +10,59 @@ logger = logging.getLogger('oml.ssl_request')
class InvalidCertificateException(http.client.HTTPException, urllib.error.URLError): class InvalidCertificateException(http.client.HTTPException, urllib.error.URLError):
def __init__(self, fingerprint, cert, reason): def __init__(self, fingerprint, cert, reason):
http.client.HTTPException.__init__(self) http.client.HTTPException.__init__(self)
self.fingerprint = fingerprint self._fingerprint = fingerprint
self.cert_fingerprint = hashlib.sha1(cert).hexdigest() self._cert_fingerprint = hashlib.sha1(cert).hexdigest()
self.reason = reason self.reason = reason
def __str__(self): def __str__(self):
return ('%s (local) != %s (remote) (%s)\n' % return ('%s (local) != %s (remote) (%s)\n' %
(self.fingerprint, self.cert_fingerprint, self.reason)) (self._fingerprint, self._cert_fingerprint, self.reason))
class CertValidatingHTTPSConnection(http.client.HTTPConnection): class FingerprintHTTPSConnection(http.client.HTTPSConnection):
default_port = http.client.HTTPS_PORT
def __init__(self, host, port=None, fingerprint=None, strict=None, **kwargs): def __init__(self, host, port=None, fingerprint=None, check_hostname=None, **kwargs):
http.client.HTTPConnection.__init__(self, host, port, strict, **kwargs) self._fingerprint = fingerprint
self.fingerprint = fingerprint if self._fingerprint:
if self.fingerprint: check_hostname = None
self.cert_reqs = ssl.CERT_REQUIRED http.client.HTTPSConnection.__init__(self, host, port,
else: check_hostname=check_hostname, **kwargs)
self.cert_reqs = ssl.CERT_NONE
self.cert_reqs = ssl.CERT_NONE
def _ValidateCertificateFingerprint(self, cert): def _check_fingerprint(self, cert):
if len(self.fingerprint) == 40: if len(self._fingerprint) == 40:
fingerprint = hashlib.sha1(cert).hexdigest() fingerprint = hashlib.sha1(cert).hexdigest()
elif len(self.fingerprint) == 64: elif len(self._fingerprint) == 64:
fingerprint = hashlib.sha256(cert).hexdigest() fingerprint = hashlib.sha256(cert).hexdigest()
elif len(self.fingerprint) == 128: elif len(self._fingerprint) == 128:
fingerprint = hashlib.sha512(cert).hexdigest() fingerprint = hashlib.sha512(cert).hexdigest()
else: else:
logging.error('unkown fingerprint length %s (%s)', self.fingerprint, len(self.fingerprint)) logging.error('unkown _fingerprint length %s (%s)',
self._fingerprint, len(self._fingerprint))
return False return False
return fingerprint == self.fingerprint return fingerprint == self._fingerprint
def connect(self): def connect(self):
sock = socket.create_connection((self.host, self.port)) http.client.HTTPSConnection.connect(self)
self.sock = ssl.wrap_socket(sock, cert_reqs=self.cert_reqs) if self._fingerprint:
#if self.cert_reqs & ssl.CERT_REQUIRED:
if self.fingerprint:
cert = self.sock.getpeercert(binary_form=True) cert = self.sock.getpeercert(binary_form=True)
if not self._ValidateCertificateFingerprint(cert): if not self._check_fingerprint(cert):
raise InvalidCertificateException(self.fingerprint, cert, raise InvalidCertificateException(self._fingerprint, cert,
'fingerprint mismatch') 'fingerprint mismatch')
#logger.debug('CIPHER %s VERSION %s', self.sock.cipher(), self.sock.ssl_version) #logger.debug('CIPHER %s VERSION %s', self.sock.cipher(), self.sock.ssl_version)
class VerifiedHTTPSHandler(urllib.request.HTTPSHandler): class FingerprintHTTPSHandler(urllib.request.HTTPSHandler):
def __init__(self, **kwargs):
urllib.request.AbstractHTTPHandler.__init__(self) def __init__(self, debuglevel=0, context=None, check_hostname=None, fingerprint=None):
self._connection_args = kwargs urllib.request.AbstractHTTPHandler.__init__(self, debuglevel)
self._context = context
self._check_hostname = check_hostname
self._fingerprint = fingerprint
def https_open(self, req): def https_open(self, req):
def http_class_wrapper(host, **kwargs): return self.do_open(FingerprintHTTPSConnection, req,
full_kwargs = dict(self._connection_args) context=self._context, check_hostname=self._check_hostname,
full_kwargs.update(kwargs) fingerprint=self._fingerprint)
print(self._connection_args)
print(kwargs)
if 'timeout' in full_kwargs:
del full_kwargs['timeout']
return CertValidatingHTTPSConnection(host, **full_kwargs)
try:
return self.do_open(http_class_wrapper, req)
except urllib.error.URLError as e:
if type(e.reason) == ssl.SSLError and e.reason.args[0] == 1:
raise InvalidCertificateException(self.fingerprint, '',
e.reason.args[1])
raise
https_request = urllib.request.HTTPSHandler.do_request_
def get_opener(fingerprint): def get_opener(fingerprint):
handler = VerifiedHTTPSHandler(fingerprint=fingerprint) handler = FingerprintHTTPSHandler(fingerprint=fingerprint)
opener = urllib.request.build_opener(handler) opener = urllib.request.build_opener(handler)
return opener return opener