update to tornado 4.0 and requests 2.3.0
This commit is contained in:
parent
060f459965
commit
f187000dc9
239 changed files with 19071 additions and 20369 deletions
14
Shared/lib/python2.7/site-packages/tornado/test/__main__.py
Normal file
14
Shared/lib/python2.7/site-packages/tornado/test/__main__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Shim to allow python -m tornado.test.
|
||||
|
||||
This only works in python 2.7+.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
from tornado.test.runtests import all, main
|
||||
|
||||
# tornado.testing.main autodiscovery relies on 'all' being present in
|
||||
# the main module, so import it here even though it is not used directly.
|
||||
# The following line prevents a pyflakes warning.
|
||||
all = all
|
||||
|
||||
main()
|
||||
|
|
@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
|||
self.finish(user)
|
||||
|
||||
def _oauth_get_user(self, access_token, callback):
|
||||
if self.get_argument('fail_in_get_user', None):
|
||||
raise Exception("failing in get_user")
|
||||
if access_token != dict(key='uiop', secret='5678'):
|
||||
raise Exception("incorrect access token %r" % access_token)
|
||||
callback(dict(email='foo@example.com'))
|
||||
|
||||
|
||||
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
|
||||
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
# Ensure that any exceptions are set on the returned Future,
|
||||
# not simply thrown into the surrounding StackContext.
|
||||
try:
|
||||
yield self.get_authenticated_user()
|
||||
except Exception as e:
|
||||
self.set_status(503)
|
||||
self.write("got exception: %s" % e)
|
||||
else:
|
||||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
|
||||
def initialize(self, version):
|
||||
self._OAUTH_VERSION = version
|
||||
|
|
@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
dict(version='1.0')),
|
||||
('/oauth10a/client/login', OAuth1ClientLoginHandler,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/login_coroutine',
|
||||
OAuth1ClientLoginCoroutineHandler,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/request_params',
|
||||
OAuth1ClientRequestParametersHandler,
|
||||
dict(version='1.0a')),
|
||||
|
|
@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
self.assertTrue('oauth_nonce' in parsed)
|
||||
self.assertTrue('oauth_signature' in parsed)
|
||||
|
||||
def test_oauth10a_get_user_coroutine_exception(self):
|
||||
response = self.fetch(
|
||||
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
self.assertEqual(response.code, 503)
|
||||
|
||||
def test_oauth2_redirect(self):
|
||||
response = self.fetch('/oauth2/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ from tornado.tcpserver import TCPServer
|
|||
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
|
||||
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
|
||||
class ReturnFutureTest(AsyncTestCase):
|
||||
@return_future
|
||||
def sync_future(self, callback):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function, with_statement
|
|||
|
||||
from hashlib import md5
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
|
|
@ -21,7 +22,8 @@ if pycurl is not None:
|
|||
@unittest.skipIf(pycurl is None, "pycurl module not present")
|
||||
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
def get_http_client(self):
|
||||
client = CurlAsyncHTTPClient(io_loop=self.io_loop)
|
||||
client = CurlAsyncHTTPClient(io_loop=self.io_loop,
|
||||
defaults=dict(allow_ipv6=False))
|
||||
# make sure AsyncHTTPClient magic doesn't give us the wrong class
|
||||
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
|
||||
return client
|
||||
|
|
@ -51,10 +53,10 @@ class DigestAuthHandler(RequestHandler):
|
|||
assert param_dict['nonce'] == nonce
|
||||
assert param_dict['username'] == username
|
||||
assert param_dict['uri'] == self.request.path
|
||||
h1 = md5('%s:%s:%s' % (username, realm, password)).hexdigest()
|
||||
h2 = md5('%s:%s' % (self.request.method,
|
||||
self.request.path)).hexdigest()
|
||||
digest = md5('%s:%s:%s' % (h1, nonce, h2)).hexdigest()
|
||||
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
|
||||
h2 = md5(utf8('%s:%s' % (self.request.method,
|
||||
self.request.path))).hexdigest()
|
||||
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
|
||||
if digest == param_dict['response']:
|
||||
self.write('ok')
|
||||
else:
|
||||
|
|
@ -66,15 +68,28 @@ class DigestAuthHandler(RequestHandler):
|
|||
(realm, nonce, opaque))
|
||||
|
||||
|
||||
class CustomReasonHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(200, "Custom reason")
|
||||
|
||||
|
||||
class CustomFailReasonHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(400, "Custom reason")
|
||||
|
||||
|
||||
@unittest.skipIf(pycurl is None, "pycurl module not present")
|
||||
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(CurlHTTPClientTestCase, self).setUp()
|
||||
self.http_client = CurlAsyncHTTPClient(self.io_loop)
|
||||
self.http_client = CurlAsyncHTTPClient(self.io_loop,
|
||||
defaults=dict(allow_ipv6=False))
|
||||
|
||||
def get_app(self):
|
||||
return Application([
|
||||
('/digest', DigestAuthHandler),
|
||||
('/custom_reason', CustomReasonHandler),
|
||||
('/custom_fail_reason', CustomFailReasonHandler),
|
||||
])
|
||||
|
||||
def test_prepare_curl_callback_stack_context(self):
|
||||
|
|
@ -97,3 +112,11 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
|
|||
response = self.fetch('/digest', auth_mode='digest',
|
||||
auth_username='foo', auth_password='bar')
|
||||
self.assertEqual(response.body, b'ok')
|
||||
|
||||
def test_custom_reason(self):
|
||||
response = self.fetch('/custom_reason')
|
||||
self.assertEqual(response.reason, "Custom reason")
|
||||
|
||||
def test_fail_custom_reason(self):
|
||||
response = self.fetch('/custom_fail_reason')
|
||||
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ class EscapeTestCase(unittest.TestCase):
|
|||
(u("<foo>"), u("<foo>")),
|
||||
(b"<foo>", b"<foo>"),
|
||||
|
||||
("<>&\"", "<>&""),
|
||||
("<>&\"'", "<>&"'"),
|
||||
("&", "&amp;"),
|
||||
|
||||
(u("<\u00e9>"), u("<\u00e9>")),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import sys
|
||||
import textwrap
|
||||
|
|
@ -8,7 +9,7 @@ import time
|
|||
import platform
|
||||
import weakref
|
||||
|
||||
from tornado.concurrent import return_future
|
||||
from tornado.concurrent import return_future, Future
|
||||
from tornado.escape import url_escape
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.ioloop import IOLoop
|
||||
|
|
@ -20,6 +21,10 @@ from tornado.web import Application, RequestHandler, asynchronous, HTTPError
|
|||
|
||||
from tornado import gen
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
|
||||
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
|
||||
|
|
@ -281,18 +286,67 @@ class GenEngineTest(AsyncTestCase):
|
|||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_multi_delayed(self):
|
||||
def test_multi_dict(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
(yield gen.Callback("k1"))("v1")
|
||||
(yield gen.Callback("k2"))("v2")
|
||||
results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2"))
|
||||
self.assertEqual(results, dict(foo="v1", bar="v2"))
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
# The following tests explicitly run with both gen.Multi
|
||||
# and gen.multi_future (Task returns a Future, so it can be used
|
||||
# with either).
|
||||
def test_multi_yieldpoint_delayed(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield [
|
||||
responses = yield gen.Multi([
|
||||
gen.Task(self.delay_callback, 3, arg="v1"),
|
||||
gen.Task(self.delay_callback, 1, arg="v2"),
|
||||
]
|
||||
])
|
||||
self.assertEqual(responses, ["v1", "v2"])
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_multi_yieldpoint_dict_delayed(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield gen.Multi(dict(
|
||||
foo=gen.Task(self.delay_callback, 3, arg="v1"),
|
||||
bar=gen.Task(self.delay_callback, 1, arg="v2"),
|
||||
))
|
||||
self.assertEqual(responses, dict(foo="v1", bar="v2"))
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_multi_future_delayed(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield gen.multi_future([
|
||||
gen.Task(self.delay_callback, 3, arg="v1"),
|
||||
gen.Task(self.delay_callback, 1, arg="v2"),
|
||||
])
|
||||
self.assertEqual(responses, ["v1", "v2"])
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_multi_future_dict_delayed(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield gen.multi_future(dict(
|
||||
foo=gen.Task(self.delay_callback, 3, arg="v1"),
|
||||
bar=gen.Task(self.delay_callback, 1, arg="v2"),
|
||||
))
|
||||
self.assertEqual(responses, dict(foo="v1", bar="v2"))
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
@skipOnTravis
|
||||
@gen_test
|
||||
def test_multi_performance(self):
|
||||
|
|
@ -304,6 +358,23 @@ class GenEngineTest(AsyncTestCase):
|
|||
end = time.time()
|
||||
self.assertLess(end - start, 1.0)
|
||||
|
||||
@gen_test
|
||||
def test_multi_empty(self):
|
||||
# Empty lists or dicts should return the same type.
|
||||
x = yield []
|
||||
self.assertTrue(isinstance(x, list))
|
||||
y = yield {}
|
||||
self.assertTrue(isinstance(y, dict))
|
||||
|
||||
@gen_test
|
||||
def test_multi_mixed_types(self):
|
||||
# A YieldPoint (Wait) and Future (Task) can be combined
|
||||
# (and use the YieldPoint codepath)
|
||||
(yield gen.Callback("k1"))("v1")
|
||||
responses = yield [gen.Wait("k1"),
|
||||
gen.Task(self.delay_callback, 3, arg="v2")]
|
||||
self.assertEqual(responses, ["v1", "v2"])
|
||||
|
||||
@gen_test
|
||||
def test_future(self):
|
||||
result = yield self.async_future(1)
|
||||
|
|
@ -314,6 +385,11 @@ class GenEngineTest(AsyncTestCase):
|
|||
results = yield [self.async_future(1), self.async_future(2)]
|
||||
self.assertEqual(results, [1, 2])
|
||||
|
||||
@gen_test
|
||||
def test_multi_dict_future(self):
|
||||
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
|
||||
self.assertEqual(results, dict(foo=1, bar=2))
|
||||
|
||||
def test_arguments(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
|
|
@ -698,8 +774,14 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
def test_replace_context_exception(self):
|
||||
# Test exception handling: exceptions thrown into the stack context
|
||||
# can be caught and replaced.
|
||||
# Note that this test and the following are for behavior that is
|
||||
# not really supported any more: coroutines no longer create a
|
||||
# stack context automatically; but one is created after the first
|
||||
# YieldPoint (i.e. not a Future).
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
|
|
@ -718,6 +800,8 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
# can be caught and ignored.
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
|
|
@ -729,6 +813,31 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
|
||||
@gen_test
|
||||
def test_moment(self):
|
||||
calls = []
|
||||
@gen.coroutine
|
||||
def f(name, yieldable):
|
||||
for i in range(5):
|
||||
calls.append(name)
|
||||
yield yieldable
|
||||
# First, confirm the behavior without moment: each coroutine
|
||||
# monopolizes the event loop until it finishes.
|
||||
immediate = Future()
|
||||
immediate.set_result(None)
|
||||
yield [f('a', immediate), f('b', immediate)]
|
||||
self.assertEqual(''.join(calls), 'aaaaabbbbb')
|
||||
|
||||
# With moment, they take turns.
|
||||
calls = []
|
||||
yield [f('a', gen.moment), f('b', gen.moment)]
|
||||
self.assertEqual(''.join(calls), 'ababababab')
|
||||
self.finished = True
|
||||
|
||||
calls = []
|
||||
yield [f('a', gen.moment), f('b', immediate)]
|
||||
self.assertEqual(''.join(calls), 'abbbbbaaaa')
|
||||
|
||||
|
||||
class GenSequenceHandler(RequestHandler):
|
||||
@asynchronous
|
||||
|
|
@ -803,7 +912,6 @@ class GenExceptionHandler(RequestHandler):
|
|||
|
||||
|
||||
class GenCoroutineExceptionHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
# This test depends on the order of the two decorators.
|
||||
|
|
@ -909,5 +1017,55 @@ class GenWebTest(AsyncHTTPTestCase):
|
|||
response = self.fetch('/async_prepare_error')
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
|
||||
class WithTimeoutTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_timeout(self):
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
|
||||
Future())
|
||||
|
||||
@gen_test
|
||||
def test_completes_before_timeout(self):
|
||||
future = Future()
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
|
||||
lambda: future.set_result('asdf'))
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@gen_test
|
||||
def test_fails_before_timeout(self):
|
||||
future = Future()
|
||||
self.io_loop.add_timeout(
|
||||
datetime.timedelta(seconds=0.1),
|
||||
lambda: future.set_exception(ZeroDivisionError))
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
|
||||
|
||||
@gen_test
|
||||
def test_already_resolved(self):
|
||||
future = Future()
|
||||
future.set_result('asdf')
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
@gen_test
|
||||
def test_timeout_concurrent_future(self):
|
||||
with futures.ThreadPoolExecutor(1) as executor:
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield gen.with_timeout(self.io_loop.time(),
|
||||
executor.submit(time.sleep, 0.1))
|
||||
|
||||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
@gen_test
|
||||
def test_completed_concurrent_future(self):
|
||||
with futures.ThreadPoolExecutor(1) as executor:
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
executor.submit(lambda: None))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from tornado.log import gen_log
|
|||
from tornado import netutil
|
||||
from tornado.stack_context import ExceptionStackContext, NullContext
|
||||
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.test.util import unittest, skipOnTravis
|
||||
from tornado.util import u, bytes_type
|
||||
from tornado.web import Application, RequestHandler, url
|
||||
|
||||
|
|
@ -110,6 +110,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
|
|||
url("/all_methods", AllMethodsHandler),
|
||||
], gzip=True)
|
||||
|
||||
@skipOnTravis
|
||||
def test_hello_world(self):
|
||||
response = self.fetch("/hello")
|
||||
self.assertEqual(response.code, 200)
|
||||
|
|
@ -309,7 +310,7 @@ Transfer-Encoding: chunked
|
|||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
def test_configure_defaults(self):
|
||||
defaults = dict(user_agent='TestDefaultUserAgent')
|
||||
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
|
||||
# Construct a new instance of the configured client class
|
||||
client = self.http_client.__class__(self.io_loop, force_instance=True,
|
||||
defaults=defaults)
|
||||
|
|
@ -355,11 +356,10 @@ Transfer-Encoding: chunked
|
|||
|
||||
@gen_test
|
||||
def test_future_http_error(self):
|
||||
try:
|
||||
with self.assertRaises(HTTPError) as context:
|
||||
yield self.http_client.fetch(self.get_url('/notfound'))
|
||||
except HTTPError as e:
|
||||
self.assertEqual(e.code, 404)
|
||||
self.assertEqual(e.response.code, 404)
|
||||
self.assertEqual(context.exception.code, 404)
|
||||
self.assertEqual(context.exception.response.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_reuse_request_from_response(self):
|
||||
|
|
@ -387,6 +387,19 @@ Transfer-Encoding: chunked
|
|||
allow_nonstandard_methods=True)
|
||||
self.assertEqual(response.body, b'OTHER')
|
||||
|
||||
@gen_test
|
||||
def test_body(self):
|
||||
hello_url = self.get_url('/hello')
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
yield self.http_client.fetch(hello_url, body='data')
|
||||
|
||||
self.assertTrue('must be empty' in str(context.exception))
|
||||
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
yield self.http_client.fetch(hello_url, method='POST')
|
||||
|
||||
self.assertTrue('must not be empty' in str(context.exception))
|
||||
|
||||
|
||||
class RequestProxyTest(unittest.TestCase):
|
||||
def test_request_set(self):
|
||||
|
|
@ -433,17 +446,22 @@ class HTTPResponseTestCase(unittest.TestCase):
|
|||
|
||||
class SyncHTTPClientTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
|
||||
if IOLoop.configured_class().__name__ in ('TwistedIOLoop',
|
||||
'AsyncIOMainLoop'):
|
||||
# TwistedIOLoop only supports the global reactor, so we can't have
|
||||
# separate IOLoops for client and server threads.
|
||||
# AsyncIOMainLoop doesn't work with the default policy
|
||||
# (although it could with some tweaks to this test and a
|
||||
# policy that created loops for non-main threads).
|
||||
raise unittest.SkipTest(
|
||||
'Sync HTTPClient not compatible with TwistedIOLoop')
|
||||
'Sync HTTPClient not compatible with TwistedIOLoop or '
|
||||
'AsyncIOMainLoop')
|
||||
self.server_ioloop = IOLoop()
|
||||
|
||||
sock, self.port = bind_unused_port()
|
||||
app = Application([('/', HelloWorldHandler)])
|
||||
server = HTTPServer(app, io_loop=self.server_ioloop)
|
||||
server.add_socket(sock)
|
||||
self.server = HTTPServer(app, io_loop=self.server_ioloop)
|
||||
self.server.add_socket(sock)
|
||||
|
||||
self.server_thread = threading.Thread(target=self.server_ioloop.start)
|
||||
self.server_thread.start()
|
||||
|
|
@ -451,7 +469,10 @@ class SyncHTTPClientTest(unittest.TestCase):
|
|||
self.http_client = HTTPClient()
|
||||
|
||||
def tearDown(self):
|
||||
self.server_ioloop.add_callback(self.server_ioloop.stop)
|
||||
def stop_server():
|
||||
self.server.stop()
|
||||
self.server_ioloop.stop()
|
||||
self.server_ioloop.add_callback(stop_server)
|
||||
self.server_thread.join()
|
||||
self.http_client.close()
|
||||
self.server_ioloop.close(all_fds=True)
|
||||
|
|
@ -469,3 +490,28 @@ class SyncHTTPClientTest(unittest.TestCase):
|
|||
with self.assertRaises(HTTPError) as assertion:
|
||||
self.http_client.fetch(self.get_url('/notfound'))
|
||||
self.assertEqual(assertion.exception.code, 404)
|
||||
|
||||
|
||||
class HTTPRequestTestCase(unittest.TestCase):
|
||||
def test_headers(self):
|
||||
request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
|
||||
self.assertEqual(request.headers, {'foo': 'bar'})
|
||||
|
||||
def test_headers_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.headers = {'bar': 'baz'}
|
||||
self.assertEqual(request.headers, {'bar': 'baz'})
|
||||
|
||||
def test_null_headers_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.headers = None
|
||||
self.assertEqual(request.headers, {})
|
||||
|
||||
def test_body(self):
|
||||
request = HTTPRequest('http://example.com', body='foo')
|
||||
self.assertEqual(request.body, utf8('foo'))
|
||||
|
||||
def test_body_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.body = 'foo'
|
||||
self.assertEqual(request.body, utf8('foo'))
|
||||
|
|
|
|||
|
|
@ -2,20 +2,23 @@
|
|||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado import httpclient, simple_httpclient, netutil
|
||||
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
|
||||
from tornado import netutil
|
||||
from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
|
||||
from tornado import gen
|
||||
from tornado.http1connection import HTTP1Connection
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import gen_log
|
||||
from tornado.netutil import ssl_options_to_context, Resolver
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import ssl_options_to_context
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, skipOnTravis
|
||||
from tornado.util import u, bytes_type
|
||||
from tornado.web import Application, RequestHandler, asynchronous
|
||||
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
|
||||
from contextlib import closing
|
||||
import datetime
|
||||
import gzip
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
|
|
@ -23,6 +26,28 @@ import ssl
|
|||
import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
from io import BytesIO # python 3
|
||||
except ImportError:
|
||||
from cStringIO import StringIO as BytesIO # python 2
|
||||
|
||||
|
||||
def read_stream_body(stream, callback):
|
||||
"""Reads an HTTP response from `stream` and runs callback with its
|
||||
headers and body."""
|
||||
chunks = []
|
||||
class Delegate(HTTPMessageDelegate):
|
||||
def headers_received(self, start_line, headers):
|
||||
self.headers = headers
|
||||
|
||||
def data_received(self, chunk):
|
||||
chunks.append(chunk)
|
||||
|
||||
def finish(self):
|
||||
callback((self.headers, b''.join(chunks)))
|
||||
conn = HTTP1Connection(stream, True)
|
||||
conn.read_response(Delegate())
|
||||
|
||||
|
||||
class HandlerBaseTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
|
|
@ -86,11 +111,13 @@ class SSLTestMixin(object):
|
|||
# connection, rather than waiting for a timeout or otherwise
|
||||
# misbehaving.
|
||||
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
|
||||
self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
|
||||
self.stop,
|
||||
request_timeout=3600,
|
||||
connect_timeout=3600)
|
||||
response = self.wait()
|
||||
with ExpectLog(gen_log, 'Uncaught exception', required=False):
|
||||
self.http_client.fetch(
|
||||
self.get_url("/").replace('https:', 'http:'),
|
||||
self.stop,
|
||||
request_timeout=3600,
|
||||
connect_timeout=3600)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
# Python's SSL implementation differs significantly between versions.
|
||||
|
|
@ -163,18 +190,7 @@ class MultipartTestHandler(RequestHandler):
|
|||
})
|
||||
|
||||
|
||||
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
|
||||
def set_request(self, request):
|
||||
self.__next_request = request
|
||||
|
||||
def _on_connect(self):
|
||||
self.stream.write(self.__next_request)
|
||||
self.__next_request = None
|
||||
self.stream.read_until(b"\r\n\r\n", self._on_headers)
|
||||
|
||||
# This test is also called from wsgi_test
|
||||
|
||||
|
||||
class HTTPConnectionTest(AsyncHTTPTestCase):
|
||||
def get_handlers(self):
|
||||
return [("/multipart", MultipartTestHandler),
|
||||
|
|
@ -184,23 +200,16 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
return Application(self.get_handlers())
|
||||
|
||||
def raw_fetch(self, headers, body):
|
||||
with closing(Resolver(io_loop=self.io_loop)) as resolver:
|
||||
with closing(SimpleAsyncHTTPClient(self.io_loop,
|
||||
resolver=resolver)) as client:
|
||||
conn = RawRequestHTTPConnection(
|
||||
self.io_loop, client,
|
||||
httpclient._RequestProxy(
|
||||
httpclient.HTTPRequest(self.get_url("/")),
|
||||
dict(httpclient.HTTPRequest._DEFAULTS)),
|
||||
None, self.stop,
|
||||
1024 * 1024, resolver)
|
||||
conn.set_request(
|
||||
b"\r\n".join(headers +
|
||||
[utf8("Content-Length: %d\r\n" % len(body))]) +
|
||||
b"\r\n" + body)
|
||||
response = self.wait()
|
||||
response.rethrow()
|
||||
return response
|
||||
with closing(IOStream(socket.socket())) as stream:
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
stream.write(
|
||||
b"\r\n".join(headers +
|
||||
[utf8("Content-Length: %d\r\n" % len(body))]) +
|
||||
b"\r\n" + body)
|
||||
read_stream_body(stream, self.stop)
|
||||
headers, body = self.wait()
|
||||
return body
|
||||
|
||||
def test_multipart_form(self):
|
||||
# Encodings here are tricky: Headers are latin1, bodies can be
|
||||
|
|
@ -211,17 +220,17 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
b"X-Header-encoding-test: \xe9",
|
||||
],
|
||||
b"\r\n".join([
|
||||
b"Content-Disposition: form-data; name=argument",
|
||||
b"",
|
||||
u("\u00e1").encode("utf-8"),
|
||||
b"--1234567890",
|
||||
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
|
||||
b"",
|
||||
u("\u00fa").encode("utf-8"),
|
||||
b"--1234567890--",
|
||||
b"",
|
||||
b"Content-Disposition: form-data; name=argument",
|
||||
b"",
|
||||
u("\u00e1").encode("utf-8"),
|
||||
b"--1234567890",
|
||||
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
|
||||
b"",
|
||||
u("\u00fa").encode("utf-8"),
|
||||
b"--1234567890--",
|
||||
b"",
|
||||
]))
|
||||
data = json_decode(response.body)
|
||||
data = json_decode(response)
|
||||
self.assertEqual(u("\u00e9"), data["header"])
|
||||
self.assertEqual(u("\u00e1"), data["argument"])
|
||||
self.assertEqual(u("\u00f3"), data["filename"])
|
||||
|
|
@ -344,6 +353,21 @@ class HTTPServerTest(AsyncHTTPTestCase):
|
|||
self.assertEqual(200, response.code)
|
||||
self.assertEqual(json_decode(response.body), {})
|
||||
|
||||
def test_malformed_body(self):
|
||||
# parse_qs is pretty forgiving, but it will fail on python 3
|
||||
# if the data is not utf8. On python 2 parse_qs will work,
|
||||
# but then the recursive_unicode call in EchoHandler will
|
||||
# fail.
|
||||
if str is bytes_type:
|
||||
return
|
||||
with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
|
||||
response = self.fetch(
|
||||
'/echo', method="POST",
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
body=b'\xe9')
|
||||
self.assertEqual(200, response.code)
|
||||
self.assertEqual(b'{}', response.body)
|
||||
|
||||
|
||||
class HTTPServerRawTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
|
|
@ -382,6 +406,25 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
|
|||
self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_chunked_request_body(self):
|
||||
# Chunked requests are not widely supported and we don't have a way
|
||||
# to generate them in AsyncHTTPClient, but HTTPServer will read them.
|
||||
self.stream.write(b"""\
|
||||
POST /echo HTTP/1.1
|
||||
Transfer-Encoding: chunked
|
||||
Content-Type: application/x-www-form-urlencoded
|
||||
|
||||
4
|
||||
foo=
|
||||
3
|
||||
bar
|
||||
0
|
||||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
read_stream_body(self.stream, self.stop)
|
||||
headers, response = self.wait()
|
||||
self.assertEqual(json_decode(response), {u('foo'): [u('bar')]})
|
||||
|
||||
|
||||
class XHeaderTest(HandlerBaseTestCase):
|
||||
class Handler(RequestHandler):
|
||||
|
|
@ -497,31 +540,40 @@ class UnixSocketTest(AsyncTestCase):
|
|||
def setUp(self):
|
||||
super(UnixSocketTest, self).setUp()
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.sockfile = os.path.join(self.tmpdir, "test.sock")
|
||||
sock = netutil.bind_unix_socket(self.sockfile)
|
||||
app = Application([("/hello", HelloWorldRequestHandler)])
|
||||
self.server = HTTPServer(app, io_loop=self.io_loop)
|
||||
self.server.add_socket(sock)
|
||||
self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
|
||||
self.stream.connect(self.sockfile, self.stop)
|
||||
self.wait()
|
||||
|
||||
def tearDown(self):
|
||||
self.stream.close()
|
||||
self.server.stop()
|
||||
shutil.rmtree(self.tmpdir)
|
||||
super(UnixSocketTest, self).tearDown()
|
||||
|
||||
def test_unix_socket(self):
|
||||
sockfile = os.path.join(self.tmpdir, "test.sock")
|
||||
sock = netutil.bind_unix_socket(sockfile)
|
||||
app = Application([("/hello", HelloWorldRequestHandler)])
|
||||
server = HTTPServer(app, io_loop=self.io_loop)
|
||||
server.add_socket(sock)
|
||||
stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
|
||||
stream.connect(sockfile, self.stop)
|
||||
self.wait()
|
||||
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
|
||||
stream.read_until(b"\r\n", self.stop)
|
||||
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
|
||||
self.stream.read_until(b"\r\n", self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
|
||||
stream.read_until(b"\r\n\r\n", self.stop)
|
||||
self.stream.read_until(b"\r\n\r\n", self.stop)
|
||||
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
|
||||
stream.read_bytes(int(headers["Content-Length"]), self.stop)
|
||||
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
|
||||
body = self.wait()
|
||||
self.assertEqual(body, b"Hello world")
|
||||
stream.close()
|
||||
server.stop()
|
||||
|
||||
def test_unix_socket_bad_request(self):
|
||||
# Unix sockets don't have remote addresses so they just return an
|
||||
# empty string.
|
||||
with ExpectLog(gen_log, "Malformed HTTP message from"):
|
||||
self.stream.write(b"garbage\r\n\r\n")
|
||||
self.stream.read_until_close(self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response, b"")
|
||||
|
||||
|
||||
class KeepAliveTest(AsyncHTTPTestCase):
|
||||
|
|
@ -586,8 +638,8 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
return headers
|
||||
|
||||
def read_response(self):
|
||||
headers = self.read_headers()
|
||||
self.stream.read_bytes(int(headers['Content-Length']), self.stop)
|
||||
self.headers = self.read_headers()
|
||||
self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
|
||||
body = self.wait()
|
||||
self.assertEqual(b'Hello world', body)
|
||||
|
||||
|
|
@ -621,6 +673,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
self.stream.read_until_close(callback=self.stop)
|
||||
data = self.wait()
|
||||
self.assertTrue(not data)
|
||||
self.assertTrue('Connection' not in self.headers)
|
||||
self.close()
|
||||
|
||||
def test_http10_keepalive(self):
|
||||
|
|
@ -628,8 +681,10 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.close()
|
||||
|
||||
def test_pipelined_requests(self):
|
||||
|
|
@ -659,3 +714,322 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
|
||||
self.read_headers()
|
||||
self.close()
|
||||
|
||||
|
||||
class GzipBaseTest(object):
|
||||
def get_app(self):
|
||||
return Application([('/', EchoHandler)])
|
||||
|
||||
def post_gzip(self, body):
|
||||
bytesio = BytesIO()
|
||||
gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
|
||||
gzip_file.write(utf8(body))
|
||||
gzip_file.close()
|
||||
compressed_body = bytesio.getvalue()
|
||||
return self.fetch('/', method='POST', body=compressed_body,
|
||||
headers={'Content-Encoding': 'gzip'})
|
||||
|
||||
def test_uncompressed(self):
|
||||
response = self.fetch('/', method='POST', body='foo=bar')
|
||||
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
|
||||
|
||||
|
||||
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
|
||||
def get_httpserver_options(self):
|
||||
return dict(decompress_request=True)
|
||||
|
||||
def test_gzip(self):
|
||||
response = self.post_gzip('foo=bar')
|
||||
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
|
||||
|
||||
|
||||
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
|
||||
def test_gzip_unsupported(self):
|
||||
# Gzip support is opt-in; without it the server fails to parse
|
||||
# the body (but parsing form bodies is currently just a log message,
|
||||
# not a fatal error).
|
||||
with ExpectLog(gen_log, "Unsupported Content-Encoding"):
|
||||
response = self.post_gzip('foo=bar')
|
||||
self.assertEquals(json_decode(response.body), {})
|
||||
|
||||
|
||||
class StreamingChunkSizeTest(AsyncHTTPTestCase):
|
||||
# 50 characters long, and repetitive so it can be compressed.
|
||||
BODY = b'01234567890123456789012345678901234567890123456789'
|
||||
CHUNK_SIZE = 16
|
||||
|
||||
def get_http_client(self):
|
||||
# body_producer doesn't work on curl_httpclient, so override the
|
||||
# configured AsyncHTTPClient implementation.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
self.chunk_lengths = []
|
||||
|
||||
def data_received(self, chunk):
|
||||
self.chunk_lengths.append(len(chunk))
|
||||
|
||||
def finish(self):
|
||||
response_body = utf8(json_encode(self.chunk_lengths))
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine('HTTP/1.1', 200, 'OK'),
|
||||
HTTPHeaders({'Content-Length': str(len(response_body))}))
|
||||
self.connection.write(response_body)
|
||||
self.connection.finish()
|
||||
|
||||
def get_app(self):
|
||||
class App(HTTPServerConnectionDelegate):
|
||||
def start_request(self, connection):
|
||||
return StreamingChunkSizeTest.MessageDelegate(connection)
|
||||
return App()
|
||||
|
||||
def fetch_chunk_sizes(self, **kwargs):
|
||||
response = self.fetch('/', method='POST', **kwargs)
|
||||
response.rethrow()
|
||||
chunks = json_decode(response.body)
|
||||
self.assertEqual(len(self.BODY), sum(chunks))
|
||||
for chunk_size in chunks:
|
||||
self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
|
||||
'oversized chunk: ' + str(chunks))
|
||||
self.assertGreater(chunk_size, 0,
|
||||
'empty chunk: ' + str(chunks))
|
||||
return chunks
|
||||
|
||||
def compress(self, body):
|
||||
bytesio = BytesIO()
|
||||
gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
|
||||
gzfile.write(body)
|
||||
gzfile.close()
|
||||
compressed = bytesio.getvalue()
|
||||
if len(compressed) >= len(body):
|
||||
raise Exception("body did not shrink when compressed")
|
||||
return compressed
|
||||
|
||||
def test_regular_body(self):
|
||||
chunks = self.fetch_chunk_sizes(body=self.BODY)
|
||||
# Without compression we know exactly what to expect.
|
||||
self.assertEqual([16, 16, 16, 2], chunks)
|
||||
|
||||
def test_compressed_body(self):
|
||||
self.fetch_chunk_sizes(body=self.compress(self.BODY),
|
||||
headers={'Content-Encoding': 'gzip'})
|
||||
# Compression creates irregular boundaries so the assertions
|
||||
# in fetch_chunk_sizes are as specific as we can get.
|
||||
|
||||
def test_chunked_body(self):
|
||||
def body_producer(write):
|
||||
write(self.BODY[:20])
|
||||
write(self.BODY[20:])
|
||||
chunks = self.fetch_chunk_sizes(body_producer=body_producer)
|
||||
# HTTP chunk boundaries translate to application-visible breaks
|
||||
self.assertEqual([16, 4, 16, 14], chunks)
|
||||
|
||||
def test_chunked_compressed(self):
|
||||
compressed = self.compress(self.BODY)
|
||||
self.assertGreater(len(compressed), 20)
|
||||
def body_producer(write):
|
||||
write(compressed[:20])
|
||||
write(compressed[20:])
|
||||
self.fetch_chunk_sizes(body_producer=body_producer,
|
||||
headers={'Content-Encoding': 'gzip'})
|
||||
|
||||
|
||||
class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application([('/', HelloWorldRequestHandler)])
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(max_header_size=1024)
|
||||
|
||||
def test_small_headers(self):
|
||||
response = self.fetch("/", headers={'X-Filler': 'a' * 100})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"Hello world")
|
||||
|
||||
def test_large_headers(self):
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
@skipOnTravis
|
||||
class IdleTimeoutTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application([('/', HelloWorldRequestHandler)])
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(idle_connection_timeout=0.1)
|
||||
|
||||
def setUp(self):
|
||||
super(IdleTimeoutTest, self).setUp()
|
||||
self.streams = []
|
||||
|
||||
def tearDown(self):
|
||||
super(IdleTimeoutTest, self).tearDown()
|
||||
for stream in self.streams:
|
||||
stream.close()
|
||||
|
||||
def connect(self):
|
||||
stream = IOStream(socket.socket())
|
||||
stream.connect(('localhost', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
self.streams.append(stream)
|
||||
return stream
|
||||
|
||||
def test_unused_connection(self):
|
||||
stream = self.connect()
|
||||
stream.set_close_callback(self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_idle_after_use(self):
|
||||
stream = self.connect()
|
||||
stream.set_close_callback(lambda: self.stop("closed"))
|
||||
|
||||
# Use the connection twice to make sure keep-alives are working
|
||||
for i in range(2):
|
||||
stream.write(b"GET / HTTP/1.1\r\n\r\n")
|
||||
stream.read_until(b"\r\n\r\n", self.stop)
|
||||
self.wait()
|
||||
stream.read_bytes(11, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"Hello world")
|
||||
|
||||
# Now let the timeout trigger and close the connection.
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
|
||||
|
||||
class BodyLimitsTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class BufferedHandler(RequestHandler):
|
||||
def put(self):
|
||||
self.write(str(len(self.request.body)))
|
||||
|
||||
@stream_request_body
|
||||
class StreamingHandler(RequestHandler):
|
||||
def initialize(self):
|
||||
self.bytes_read = 0
|
||||
|
||||
def prepare(self):
|
||||
if 'expected_size' in self.request.arguments:
|
||||
self.request.connection.set_max_body_size(
|
||||
int(self.get_argument('expected_size')))
|
||||
if 'body_timeout' in self.request.arguments:
|
||||
self.request.connection.set_body_timeout(
|
||||
float(self.get_argument('body_timeout')))
|
||||
|
||||
def data_received(self, data):
|
||||
self.bytes_read += len(data)
|
||||
|
||||
def put(self):
|
||||
self.write(str(self.bytes_read))
|
||||
|
||||
return Application([('/buffered', BufferedHandler),
|
||||
('/streaming', StreamingHandler)])
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(body_timeout=3600, max_body_size=4096)
|
||||
|
||||
def get_http_client(self):
|
||||
# body_producer doesn't work on curl_httpclient, so override the
|
||||
# configured AsyncHTTPClient implementation.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
|
||||
def test_small_body(self):
|
||||
response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
|
||||
self.assertEqual(response.body, b'4096')
|
||||
response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
|
||||
self.assertEqual(response.body, b'4096')
|
||||
|
||||
def test_large_body_buffered(self):
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
def test_large_body_buffered_chunked(self):
|
||||
with ExpectLog(gen_log, '.*chunked body too large'):
|
||||
response = self.fetch('/buffered', method='PUT',
|
||||
body_producer=lambda write: write(b'a' * 10240))
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
def test_large_body_streaming(self):
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
def test_large_body_streaming_chunked(self):
|
||||
with ExpectLog(gen_log, '.*chunked body too large'):
|
||||
response = self.fetch('/streaming', method='PUT',
|
||||
body_producer=lambda write: write(b'a' * 10240))
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
def test_large_body_streaming_override(self):
|
||||
response = self.fetch('/streaming?expected_size=10240', method='PUT',
|
||||
body=b'a' * 10240)
|
||||
self.assertEqual(response.body, b'10240')
|
||||
|
||||
def test_large_body_streaming_chunked_override(self):
|
||||
response = self.fetch('/streaming?expected_size=10240', method='PUT',
|
||||
body_producer=lambda write: write(b'a' * 10240))
|
||||
self.assertEqual(response.body, b'10240')
|
||||
|
||||
@gen_test
|
||||
def test_timeout(self):
|
||||
stream = IOStream(socket.socket())
|
||||
try:
|
||||
yield stream.connect(('127.0.0.1', self.get_http_port()))
|
||||
# Use a raw stream because AsyncHTTPClient won't let us read a
|
||||
# response without finishing a body.
|
||||
stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
|
||||
b'Content-Length: 42\r\n\r\n')
|
||||
with ExpectLog(gen_log, 'Timeout reading body'):
|
||||
response = yield stream.read_until_close()
|
||||
self.assertEqual(response, b'')
|
||||
finally:
|
||||
stream.close()
|
||||
|
||||
@gen_test
|
||||
def test_body_size_override_reset(self):
|
||||
# The max_body_size override is reset between requests.
|
||||
stream = IOStream(socket.socket())
|
||||
try:
|
||||
yield stream.connect(('127.0.0.1', self.get_http_port()))
|
||||
# Use a raw stream so we can make sure it's all on one connection.
|
||||
stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
|
||||
b'Content-Length: 10240\r\n\r\n')
|
||||
stream.write(b'a' * 10240)
|
||||
headers, response = yield gen.Task(read_stream_body, stream)
|
||||
self.assertEqual(response, b'10240')
|
||||
# Without the ?expected_size parameter, we get the old default value
|
||||
stream.write(b'PUT /streaming HTTP/1.1\r\n'
|
||||
b'Content-Length: 10240\r\n\r\n')
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
data = yield stream.read_until_close()
|
||||
self.assertEqual(data, b'')
|
||||
finally:
|
||||
stream.close()
|
||||
|
||||
|
||||
class LegacyInterfaceTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
# The old request_callback interface does not implement the
|
||||
# delegate interface, and writes its response via request.write
|
||||
# instead of request.connection.write_headers.
|
||||
def handle_request(request):
|
||||
message = b"Hello world"
|
||||
request.write(utf8("HTTP/1.1 200 OK\r\n"
|
||||
"Content-Length: %d\r\n\r\n" % len(message)))
|
||||
request.write(message)
|
||||
request.finish()
|
||||
return handle_request
|
||||
|
||||
def test_legacy_interface(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.body, b"Hello world")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class ImportTest(unittest.TestCase):
|
|||
# import tornado.curl_httpclient # depends on pycurl
|
||||
import tornado.escape
|
||||
import tornado.gen
|
||||
import tornado.http1connection
|
||||
import tornado.httpclient
|
||||
import tornado.httpserver
|
||||
import tornado.httputil
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ import time
|
|||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop, TimeoutError
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
|
||||
|
||||
try:
|
||||
|
|
@ -51,7 +52,8 @@ class TestIOLoop(AsyncTestCase):
|
|||
thread = threading.Thread(target=target)
|
||||
self.io_loop.add_callback(thread.start)
|
||||
self.wait()
|
||||
self.assertAlmostEqual(time.time(), self.stop_time, places=2)
|
||||
delta = time.time() - self.stop_time
|
||||
self.assertLess(delta, 0.1)
|
||||
thread.join()
|
||||
|
||||
def test_add_timeout_timedelta(self):
|
||||
|
|
@ -153,7 +155,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
|
||||
def test_remove_timeout_after_fire(self):
|
||||
# It is not an error to call remove_timeout after it has run.
|
||||
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop())
|
||||
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
|
||||
self.wait()
|
||||
self.io_loop.remove_timeout(handle)
|
||||
|
||||
|
|
@ -171,6 +173,131 @@ class TestIOLoop(AsyncTestCase):
|
|||
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
|
||||
self.wait()
|
||||
|
||||
def test_timeout_with_arguments(self):
|
||||
# This tests that all the timeout methods pass through *args correctly.
|
||||
results = []
|
||||
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
|
||||
results.append, 2)
|
||||
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
|
||||
self.io_loop.call_later(0, results.append, 4)
|
||||
self.io_loop.call_later(0, self.stop)
|
||||
self.wait()
|
||||
self.assertEqual(results, [1, 2, 3, 4])
|
||||
|
||||
def test_close_file_object(self):
|
||||
"""When a file object is used instead of a numeric file descriptor,
|
||||
the object should be closed (by IOLoop.close(all_fds=True),
|
||||
not just the fd.
|
||||
"""
|
||||
# Use a socket since they are supported by IOLoop on all platforms.
|
||||
# Unfortunately, sockets don't support the .closed attribute for
|
||||
# inspecting their close status, so we must use a wrapper.
|
||||
class SocketWrapper(object):
|
||||
def __init__(self, sockobj):
|
||||
self.sockobj = sockobj
|
||||
self.closed = False
|
||||
|
||||
def fileno(self):
|
||||
return self.sockobj.fileno()
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
self.sockobj.close()
|
||||
sockobj, port = bind_unused_port()
|
||||
socket_wrapper = SocketWrapper(sockobj)
|
||||
io_loop = IOLoop()
|
||||
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
|
||||
IOLoop.READ)
|
||||
io_loop.close(all_fds=True)
|
||||
self.assertTrue(socket_wrapper.closed)
|
||||
|
||||
def test_handler_callback_file_object(self):
|
||||
"""The handler callback receives the same fd object it passed in."""
|
||||
server_sock, port = bind_unused_port()
|
||||
fds = []
|
||||
def handle_connection(fd, events):
|
||||
fds.append(fd)
|
||||
conn, addr = server_sock.accept()
|
||||
conn.close()
|
||||
self.stop()
|
||||
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
|
||||
with contextlib.closing(socket.socket()) as client_sock:
|
||||
client_sock.connect(('127.0.0.1', port))
|
||||
self.wait()
|
||||
self.io_loop.remove_handler(server_sock)
|
||||
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
|
||||
IOLoop.READ)
|
||||
with contextlib.closing(socket.socket()) as client_sock:
|
||||
client_sock.connect(('127.0.0.1', port))
|
||||
self.wait()
|
||||
self.assertIs(fds[0], server_sock)
|
||||
self.assertEqual(fds[1], server_sock.fileno())
|
||||
self.io_loop.remove_handler(server_sock.fileno())
|
||||
server_sock.close()
|
||||
|
||||
def test_mixed_fd_fileobj(self):
|
||||
server_sock, port = bind_unused_port()
|
||||
def f(fd, events):
|
||||
pass
|
||||
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
|
||||
with self.assertRaises(Exception):
|
||||
# The exact error is unspecified - some implementations use
|
||||
# IOError, others use ValueError.
|
||||
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
|
||||
self.io_loop.remove_handler(server_sock.fileno())
|
||||
server_sock.close()
|
||||
|
||||
def test_reentrant(self):
|
||||
"""Calling start() twice should raise an error, not deadlock."""
|
||||
returned_from_start = [False]
|
||||
got_exception = [False]
|
||||
def callback():
|
||||
try:
|
||||
self.io_loop.start()
|
||||
returned_from_start[0] = True
|
||||
except Exception:
|
||||
got_exception[0] = True
|
||||
self.stop()
|
||||
self.io_loop.add_callback(callback)
|
||||
self.wait()
|
||||
self.assertTrue(got_exception[0])
|
||||
self.assertFalse(returned_from_start[0])
|
||||
|
||||
def test_exception_logging(self):
|
||||
"""Uncaught exceptions get logged by the IOLoop."""
|
||||
# Use a NullContext to keep the exception from being caught by
|
||||
# AsyncTestCase.
|
||||
with NullContext():
|
||||
self.io_loop.add_callback(lambda: 1/0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
def test_exception_logging_future(self):
|
||||
"""The IOLoop examines exceptions from Futures and logs them."""
|
||||
with NullContext():
|
||||
@gen.coroutine
|
||||
def callback():
|
||||
self.io_loop.add_callback(self.stop)
|
||||
1/0
|
||||
self.io_loop.add_callback(callback)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
def test_spawn_callback(self):
|
||||
# An added callback runs in the test's stack_context, so will be
|
||||
# re-arised in wait().
|
||||
self.io_loop.add_callback(lambda: 1/0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.wait()
|
||||
# A spawned callback is run directly on the IOLoop, so it will be
|
||||
# logged without stopping the test.
|
||||
self.io_loop.spawn_callback(lambda: 1/0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
|
||||
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
|
||||
# automatically set as current.
|
||||
|
|
@ -329,5 +456,6 @@ class TestIOLoopRunSync(unittest.TestCase):
|
|||
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
|
||||
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado import netutil
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
|
||||
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import ssl_wrap_socket
|
||||
from tornado.stack_context import NullContext
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, skipIfNonUnix
|
||||
from tornado.web import RequestHandler, Application
|
||||
import certifi
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -17,6 +20,13 @@ import ssl
|
|||
import sys
|
||||
|
||||
|
||||
def _server_ssl_options():
|
||||
return dict(
|
||||
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
|
||||
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
|
||||
)
|
||||
|
||||
|
||||
class HelloHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write("Hello")
|
||||
|
|
@ -106,6 +116,48 @@ class TestIOStreamWebMixin(object):
|
|||
|
||||
stream.close()
|
||||
|
||||
@gen_test
|
||||
def test_future_interface(self):
|
||||
"""Basic test of IOStream's ability to return Futures."""
|
||||
stream = self._make_client_iostream()
|
||||
connect_result = yield stream.connect(
|
||||
("localhost", self.get_http_port()))
|
||||
self.assertIs(connect_result, stream)
|
||||
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
first_line = yield stream.read_until(b"\r\n")
|
||||
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
|
||||
# callback=None is equivalent to no callback.
|
||||
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
|
||||
headers = HTTPHeaders.parse(header_data.decode('latin1'))
|
||||
content_length = int(headers['Content-Length'])
|
||||
body = yield stream.read_bytes(content_length)
|
||||
self.assertEqual(body, b'Hello')
|
||||
stream.close()
|
||||
|
||||
@gen_test
|
||||
def test_future_close_while_reading(self):
|
||||
stream = self._make_client_iostream()
|
||||
yield stream.connect(("localhost", self.get_http_port()))
|
||||
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
with self.assertRaises(StreamClosedError):
|
||||
yield stream.read_bytes(1024 * 1024)
|
||||
stream.close()
|
||||
|
||||
@gen_test
|
||||
def test_future_read_until_close(self):
|
||||
# Ensure that the data comes through before the StreamClosedError.
|
||||
stream = self._make_client_iostream()
|
||||
yield stream.connect(("localhost", self.get_http_port()))
|
||||
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
|
||||
yield stream.read_until(b"\r\n\r\n")
|
||||
body = yield stream.read_until_close()
|
||||
self.assertEqual(body, b"Hello")
|
||||
|
||||
# Nothing else to read; the error comes immediately without waiting
|
||||
# for yield.
|
||||
with self.assertRaises(StreamClosedError):
|
||||
stream.read_bytes(1)
|
||||
|
||||
|
||||
class TestIOStreamMixin(object):
|
||||
def _make_server_iostream(self, connection, **kwargs):
|
||||
|
|
@ -120,16 +172,6 @@ class TestIOStreamMixin(object):
|
|||
|
||||
def accept_callback(connection, address):
|
||||
streams[0] = self._make_server_iostream(connection, **kwargs)
|
||||
if isinstance(streams[0], SSLIOStream):
|
||||
# HACK: The SSL handshake won't complete (and
|
||||
# therefore the client connect callback won't be
|
||||
# run)until the server side has tried to do something
|
||||
# with the connection. For these tests we want both
|
||||
# sides to connect before we do anything else with the
|
||||
# connection, so we must cause some dummy activity on the
|
||||
# server. If this turns out to be useful for real apps
|
||||
# it should have a cleaner interface.
|
||||
streams[0]._add_io_state(IOLoop.READ)
|
||||
self.stop()
|
||||
|
||||
def connect_callback():
|
||||
|
|
@ -168,9 +210,6 @@ class TestIOStreamMixin(object):
|
|||
server, client = self.make_iostream_pair()
|
||||
server.write(b'', callback=self.stop)
|
||||
self.wait()
|
||||
# As a side effect, the stream is now listening for connection
|
||||
# close (if it wasn't already), but is not listening for writes
|
||||
self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
|
|
@ -193,8 +232,11 @@ class TestIOStreamMixin(object):
|
|||
self.assertFalse(self.connect_called)
|
||||
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
|
||||
if sys.platform != 'cygwin':
|
||||
_ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
|
||||
if hasattr(errno, "WSAECONNREFUSED"):
|
||||
_ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
|
||||
# cygwin's errnos don't match those used on native windows python
|
||||
self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)
|
||||
self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
|
||||
|
||||
def test_gaierror(self):
|
||||
# Test that IOStream sets its exc_info on getaddrinfo error
|
||||
|
|
@ -308,6 +350,25 @@ class TestIOStreamMixin(object):
|
|||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_future_delayed_close_callback(self):
|
||||
# Same as test_delayed_close_callback, but with the future interface.
|
||||
server, client = self.make_iostream_pair()
|
||||
# We can't call make_iostream_pair inside a gen_test function
|
||||
# because the ioloop is not reentrant.
|
||||
@gen_test
|
||||
def f(self):
|
||||
server.write(b"12")
|
||||
chunks = []
|
||||
chunks.append((yield client.read_bytes(1)))
|
||||
server.close()
|
||||
chunks.append((yield client.read_bytes(1)))
|
||||
self.assertEqual(chunks, [b"1", b"2"])
|
||||
try:
|
||||
f(self)
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_close_buffered_data(self):
|
||||
# Similar to the previous test, but with data stored in the OS's
|
||||
# socket buffers instead of the IOStream's read buffer. Out-of-band
|
||||
|
|
@ -340,14 +401,18 @@ class TestIOStreamMixin(object):
|
|||
# Similar to test_delayed_close_callback, but read_until_close takes
|
||||
# a separate code path so test it separately.
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(self.stop)
|
||||
try:
|
||||
server.write(b"1234")
|
||||
server.close()
|
||||
self.wait()
|
||||
# Read one byte to make sure the client has received the data.
|
||||
# It won't run the close callback as long as there is more buffered
|
||||
# data that could satisfy a later read.
|
||||
client.read_bytes(1, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"1")
|
||||
client.read_until_close(self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"1234")
|
||||
self.assertEqual(data, b"234")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
|
@ -357,17 +422,18 @@ class TestIOStreamMixin(object):
|
|||
# All data should go through the streaming callback,
|
||||
# and the final read callback just gets an empty string.
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(self.stop)
|
||||
try:
|
||||
server.write(b"1234")
|
||||
server.close()
|
||||
self.wait()
|
||||
client.read_bytes(1, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"1")
|
||||
streaming_data = []
|
||||
client.read_until_close(self.stop,
|
||||
streaming_callback=streaming_data.append)
|
||||
data = self.wait()
|
||||
self.assertEqual(b'', data)
|
||||
self.assertEqual(b''.join(streaming_data), b"1234")
|
||||
self.assertEqual(b''.join(streaming_data), b"234")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
|
@ -461,6 +527,203 @@ class TestIOStreamMixin(object):
|
|||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_future_close_callback(self):
|
||||
# Regression test for interaction between the Future read interfaces
|
||||
# and IOStream._maybe_add_error_listener.
|
||||
server, client = self.make_iostream_pair()
|
||||
closed = [False]
|
||||
def close_callback():
|
||||
closed[0] = True
|
||||
self.stop()
|
||||
server.set_close_callback(close_callback)
|
||||
try:
|
||||
client.write(b'a')
|
||||
future = server.read_bytes(1)
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
self.assertEqual(self.wait().result(), b'a')
|
||||
self.assertFalse(closed[0])
|
||||
client.close()
|
||||
self.wait()
|
||||
self.assertTrue(closed[0])
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_bytes_partial(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
try:
|
||||
# Ask for more than is available with partial=True
|
||||
client.read_bytes(50, self.stop, partial=True)
|
||||
server.write(b"hello")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"hello")
|
||||
|
||||
# Ask for less than what is available; num_bytes is still
|
||||
# respected.
|
||||
client.read_bytes(3, self.stop, partial=True)
|
||||
server.write(b"world")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"wor")
|
||||
|
||||
# Partial reads won't return an empty string, but read_bytes(0)
|
||||
# will.
|
||||
client.read_bytes(0, self.stop, partial=True)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b'')
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_max_bytes(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Extra room under the limit
|
||||
client.read_until(b"def", self.stop, max_bytes=50)
|
||||
server.write(b"abcdef")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"abcdef")
|
||||
|
||||
# Just enough space
|
||||
client.read_until(b"def", self.stop, max_bytes=6)
|
||||
server.write(b"abcdef")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"abcdef")
|
||||
|
||||
# Not enough space, but we don't know it until all we can do is
|
||||
# log a warning and close the connection.
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until(b"def", self.stop, max_bytes=5)
|
||||
server.write(b"123456")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_max_bytes_inline(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Similar to the error case in the previous test, but the
|
||||
# server writes first so client reads are satisfied
|
||||
# inline. For consistency with the out-of-line case, we
|
||||
# do not raise the error synchronously.
|
||||
server.write(b"123456")
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until(b"def", self.stop, max_bytes=5)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_max_bytes_ignores_extra(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Even though data that matches arrives the same packet that
|
||||
# puts us over the limit, we fail the request because it was not
|
||||
# found within the limit.
|
||||
server.write(b"abcdef")
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until(b"def", self.stop, max_bytes=5)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_regex_max_bytes(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Extra room under the limit
|
||||
client.read_until_regex(b"def", self.stop, max_bytes=50)
|
||||
server.write(b"abcdef")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"abcdef")
|
||||
|
||||
# Just enough space
|
||||
client.read_until_regex(b"def", self.stop, max_bytes=6)
|
||||
server.write(b"abcdef")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"abcdef")
|
||||
|
||||
# Not enough space, but we don't know it until all we can do is
|
||||
# log a warning and close the connection.
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until_regex(b"def", self.stop, max_bytes=5)
|
||||
server.write(b"123456")
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_regex_max_bytes_inline(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Similar to the error case in the previous test, but the
|
||||
# server writes first so client reads are satisfied
|
||||
# inline. For consistency with the out-of-line case, we
|
||||
# do not raise the error synchronously.
|
||||
server.write(b"123456")
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until_regex(b"def", self.stop, max_bytes=5)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_read_until_regex_max_bytes_ignores_extra(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
client.set_close_callback(lambda: self.stop("closed"))
|
||||
try:
|
||||
# Even though data that matches arrives the same packet that
|
||||
# puts us over the limit, we fail the request because it was not
|
||||
# found within the limit.
|
||||
server.write(b"abcdef")
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
client.read_until_regex(b"def", self.stop, max_bytes=5)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_small_reads_from_large_buffer(self):
|
||||
# 10KB buffer size, 100KB available to read.
|
||||
# Read 1KB at a time and make sure that the buffer is not eagerly
|
||||
# filled.
|
||||
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
|
||||
try:
|
||||
server.write(b"a" * 1024 * 100)
|
||||
for i in range(100):
|
||||
client.read_bytes(1024, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"a" * 1024)
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_small_read_untils_from_large_buffer(self):
|
||||
# 10KB buffer size, 100KB available to read.
|
||||
# Read 1KB at a time and make sure that the buffer is not eagerly
|
||||
# filled.
|
||||
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
|
||||
try:
|
||||
server.write((b"a" * 1023 + b"\n") * 100)
|
||||
for i in range(100):
|
||||
client.read_until(b"\n", self.stop, max_bytes=4096)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"a" * 1023 + b"\n")
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
|
||||
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
|
||||
def _make_client_iostream(self):
|
||||
|
|
@ -482,14 +745,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase):
|
|||
|
||||
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
|
||||
def _make_server_iostream(self, connection, **kwargs):
|
||||
ssl_options = dict(
|
||||
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
|
||||
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
|
||||
)
|
||||
connection = ssl.wrap_socket(connection,
|
||||
server_side=True,
|
||||
do_handshake_on_connect=False,
|
||||
**ssl_options)
|
||||
**_server_ssl_options())
|
||||
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
|
||||
|
||||
def _make_client_iostream(self, connection, **kwargs):
|
||||
|
|
@ -517,6 +776,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
|
|||
ssl_options=context, **kwargs)
|
||||
|
||||
|
||||
class TestIOStreamStartTLS(AsyncTestCase):
|
||||
def setUp(self):
|
||||
try:
|
||||
super(TestIOStreamStartTLS, self).setUp()
|
||||
self.listener, self.port = bind_unused_port()
|
||||
self.server_stream = None
|
||||
self.server_accepted = Future()
|
||||
netutil.add_accept_handler(self.listener, self.accept)
|
||||
self.client_stream = IOStream(socket.socket())
|
||||
self.io_loop.add_future(self.client_stream.connect(
|
||||
('127.0.0.1', self.port)), self.stop)
|
||||
self.wait()
|
||||
self.io_loop.add_future(self.server_accepted, self.stop)
|
||||
self.wait()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise
|
||||
|
||||
def tearDown(self):
|
||||
if self.server_stream is not None:
|
||||
self.server_stream.close()
|
||||
if self.client_stream is not None:
|
||||
self.client_stream.close()
|
||||
self.listener.close()
|
||||
super(TestIOStreamStartTLS, self).tearDown()
|
||||
|
||||
def accept(self, connection, address):
|
||||
if self.server_stream is not None:
|
||||
self.fail("should only get one connection")
|
||||
self.server_stream = IOStream(connection)
|
||||
self.server_accepted.set_result(None)
|
||||
|
||||
@gen.coroutine
|
||||
def client_send_line(self, line):
|
||||
self.client_stream.write(line)
|
||||
recv_line = yield self.server_stream.read_until(b"\r\n")
|
||||
self.assertEqual(line, recv_line)
|
||||
|
||||
@gen.coroutine
|
||||
def server_send_line(self, line):
|
||||
self.server_stream.write(line)
|
||||
recv_line = yield self.client_stream.read_until(b"\r\n")
|
||||
self.assertEqual(line, recv_line)
|
||||
|
||||
def client_start_tls(self, ssl_options=None):
|
||||
client_stream = self.client_stream
|
||||
self.client_stream = None
|
||||
return client_stream.start_tls(False, ssl_options)
|
||||
|
||||
def server_start_tls(self, ssl_options=None):
|
||||
server_stream = self.server_stream
|
||||
self.server_stream = None
|
||||
return server_stream.start_tls(True, ssl_options)
|
||||
|
||||
@gen_test
|
||||
def test_start_tls_smtp(self):
|
||||
# This flow is simplified from RFC 3207 section 5.
|
||||
# We don't really need all of this, but it helps to make sure
|
||||
# that after realistic back-and-forth traffic the buffers end up
|
||||
# in a sane state.
|
||||
yield self.server_send_line(b"220 mail.example.com ready\r\n")
|
||||
yield self.client_send_line(b"EHLO mail.example.com\r\n")
|
||||
yield self.server_send_line(b"250-mail.example.com welcome\r\n")
|
||||
yield self.server_send_line(b"250 STARTTLS\r\n")
|
||||
yield self.client_send_line(b"STARTTLS\r\n")
|
||||
yield self.server_send_line(b"220 Go ahead\r\n")
|
||||
client_future = self.client_start_tls()
|
||||
server_future = self.server_start_tls(_server_ssl_options())
|
||||
self.client_stream = yield client_future
|
||||
self.server_stream = yield server_future
|
||||
self.assertTrue(isinstance(self.client_stream, SSLIOStream))
|
||||
self.assertTrue(isinstance(self.server_stream, SSLIOStream))
|
||||
yield self.client_send_line(b"EHLO mail.example.com\r\n")
|
||||
yield self.server_send_line(b"250 mail.example.com welcome\r\n")
|
||||
|
||||
@gen_test
|
||||
def test_handshake_fail(self):
|
||||
self.server_start_tls(_server_ssl_options())
|
||||
client_future = self.client_start_tls(
|
||||
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
|
||||
with ExpectLog(gen_log, "SSL Error"):
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
yield client_future
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
class TestPipeIOStream(AsyncTestCase):
|
||||
def test_pipe_iostream(self):
|
||||
|
|
@ -543,3 +887,21 @@ class TestPipeIOStream(AsyncTestCase):
|
|||
self.assertEqual(data, b"ld")
|
||||
|
||||
rs.close()
|
||||
|
||||
def test_pipe_iostream_big_write(self):
|
||||
r, w = os.pipe()
|
||||
|
||||
rs = PipeIOStream(r, io_loop=self.io_loop)
|
||||
ws = PipeIOStream(w, io_loop=self.io_loop)
|
||||
|
||||
NUM_BYTES = 1048576
|
||||
|
||||
# Write 1MB of data, which should fill the buffer
|
||||
ws.write(b"1" * NUM_BYTES)
|
||||
|
||||
rs.read_bytes(NUM_BYTES, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"1" * NUM_BYTES)
|
||||
|
||||
ws.close()
|
||||
rs.close()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import glob
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
|
|
@ -52,7 +54,6 @@ class LogFormatterTest(unittest.TestCase):
|
|||
logging.ERROR: u("\u0001"),
|
||||
}
|
||||
self.formatter._normal = u("\u0002")
|
||||
self.formatter._color = True
|
||||
# construct a Logger directly to bypass getLogger's caching
|
||||
self.logger = logging.Logger('LogFormatterTest')
|
||||
self.logger.propagate = False
|
||||
|
|
@ -157,3 +158,50 @@ class EnablePrettyLoggingTest(unittest.TestCase):
|
|||
for filename in glob.glob(tmpdir + '/test_log*'):
|
||||
os.unlink(filename)
|
||||
os.rmdir(tmpdir)
|
||||
|
||||
|
||||
class LoggingOptionTest(unittest.TestCase):
|
||||
"""Test the ability to enable and disable Tornado's logging hooks."""
|
||||
def logs_present(self, statement, args=None):
|
||||
# Each test may manipulate and/or parse the options and then logs
|
||||
# a line at the 'info' level. This level is ignored in the
|
||||
# logging module by default, but Tornado turns it on by default
|
||||
# so it is the easiest way to tell whether tornado's logging hooks
|
||||
# ran.
|
||||
IMPORT = 'from tornado.options import options, parse_command_line'
|
||||
LOG_INFO = 'import logging; logging.info("hello")'
|
||||
program = ';'.join([IMPORT, statement, LOG_INFO])
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, '-c', program] + (args or []),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
stdout, stderr = proc.communicate()
|
||||
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
|
||||
return b'hello' in stdout
|
||||
|
||||
def test_default(self):
|
||||
self.assertFalse(self.logs_present('pass'))
|
||||
|
||||
def test_tornado_default(self):
|
||||
self.assertTrue(self.logs_present('parse_command_line()'))
|
||||
|
||||
def test_disable_command_line(self):
|
||||
self.assertFalse(self.logs_present('parse_command_line()',
|
||||
['--logging=none']))
|
||||
|
||||
def test_disable_command_line_case_insensitive(self):
|
||||
self.assertFalse(self.logs_present('parse_command_line()',
|
||||
['--logging=None']))
|
||||
|
||||
def test_disable_code_string(self):
|
||||
self.assertFalse(self.logs_present(
|
||||
'options.logging = "none"; parse_command_line()'))
|
||||
|
||||
def test_disable_code_none(self):
|
||||
self.assertFalse(self.logs_present(
|
||||
'options.logging = None; parse_command_line()'))
|
||||
|
||||
def test_disable_override(self):
|
||||
# command line trumps code defaults
|
||||
self.assertTrue(self.logs_present(
|
||||
'options.logging = None; parse_command_line()',
|
||||
['--logging=info']))
|
||||
|
|
|
|||
|
|
@ -1,10 +1,16 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
import time
|
||||
|
||||
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip
|
||||
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.test.util import unittest, skipIfNoNetwork
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
|
|
@ -20,6 +26,7 @@ else:
|
|||
|
||||
try:
|
||||
import twisted
|
||||
import twisted.names
|
||||
except ImportError:
|
||||
twisted = None
|
||||
else:
|
||||
|
|
@ -27,6 +34,15 @@ else:
|
|||
|
||||
|
||||
class _ResolverTestMixin(object):
|
||||
def skipOnCares(self):
|
||||
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
|
||||
# with an NXDOMAIN status code. Most resolvers treat this as an error;
|
||||
# C-ares returns the results, making the "bad_host" tests unreliable.
|
||||
# C-ares will try to resolve even malformed names, such as the
|
||||
# name with spaces used in this test.
|
||||
if self.resolver.__class__.__name__ == 'CaresResolver':
|
||||
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
|
||||
|
||||
def test_localhost(self):
|
||||
self.resolver.resolve('localhost', 80, callback=self.stop)
|
||||
result = self.wait()
|
||||
|
|
@ -39,13 +55,34 @@ class _ResolverTestMixin(object):
|
|||
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
|
||||
addrinfo)
|
||||
|
||||
def test_bad_host(self):
|
||||
self.skipOnCares()
|
||||
def handler(exc_typ, exc_val, exc_tb):
|
||||
self.stop(exc_val)
|
||||
return True # Halt propagation.
|
||||
|
||||
with ExceptionStackContext(handler):
|
||||
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
|
||||
|
||||
result = self.wait()
|
||||
self.assertIsInstance(result, Exception)
|
||||
|
||||
@gen_test
|
||||
def test_future_interface_bad_host(self):
|
||||
self.skipOnCares()
|
||||
with self.assertRaises(Exception):
|
||||
yield self.resolver.resolve('an invalid domain', 80,
|
||||
socket.AF_UNSPEC)
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(BlockingResolverTest, self).setUp()
|
||||
self.resolver = BlockingResolver(io_loop=self.io_loop)
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
|
|
@ -57,6 +94,34 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
|||
super(ThreadedResolverTest, self).tearDown()
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
|
||||
class ThreadedResolverImportTest(unittest.TestCase):
|
||||
def test_import(self):
|
||||
TIMEOUT = 5
|
||||
|
||||
# Test for a deadlock when importing a module that runs the
|
||||
# ThreadedResolver at import-time. See resolve_test.py for
|
||||
# full explanation.
|
||||
command = [
|
||||
sys.executable,
|
||||
'-c',
|
||||
'import tornado.test.resolve_test_helper']
|
||||
|
||||
start = time.time()
|
||||
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
|
||||
while time.time() - start < TIMEOUT:
|
||||
return_code = popen.poll()
|
||||
if return_code is not None:
|
||||
self.assertEqual(0, return_code)
|
||||
return # Success.
|
||||
time.sleep(0.05)
|
||||
|
||||
self.fail("import timed out")
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(pycares is None, "pycares module not present")
|
||||
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
|
|
@ -64,6 +129,7 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
|
|||
self.resolver = CaresResolver(io_loop=self.io_loop)
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(twisted is None, "twisted module not present")
|
||||
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
|
||||
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
|
|
@ -82,3 +148,21 @@ class IsValidIPTest(unittest.TestCase):
|
|||
self.assertTrue(not is_valid_ip('localhost'))
|
||||
self.assertTrue(not is_valid_ip('4.4.4.4<'))
|
||||
self.assertTrue(not is_valid_ip(' 127.0.0.1'))
|
||||
self.assertTrue(not is_valid_ip(''))
|
||||
self.assertTrue(not is_valid_ip(' '))
|
||||
self.assertTrue(not is_valid_ip('\n'))
|
||||
self.assertTrue(not is_valid_ip('\x00'))
|
||||
|
||||
|
||||
class TestPortAllocation(unittest.TestCase):
|
||||
def test_same_port_allocation(self):
|
||||
if 'TRAVIS' in os.environ:
|
||||
self.skipTest("dual-stack servers often have port conflicts on travis")
|
||||
sockets = bind_sockets(None, 'localhost')
|
||||
try:
|
||||
port = sockets[0].getsockname()[1]
|
||||
self.assertTrue(all(s.getsockname()[1] == port
|
||||
for s in sockets[1:]))
|
||||
finally:
|
||||
for sock in sockets:
|
||||
sock.close()
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@ from tornado.web import RequestHandler, Application
|
|||
|
||||
|
||||
def skip_if_twisted():
|
||||
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
|
||||
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
|
||||
if IOLoop.configured_class().__name__.endswith(('TwistedIOLoop',
|
||||
'AsyncIOMainLoop')):
|
||||
raise unittest.SkipTest("Process tests not compatible with "
|
||||
"TwistedIOLoop or AsyncIOMainLoop")
|
||||
|
||||
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
|
||||
|
||||
|
|
@ -135,6 +137,14 @@ class ProcessTest(unittest.TestCase):
|
|||
@skipIfNonUnix
|
||||
class SubprocessTest(AsyncTestCase):
|
||||
def test_subprocess(self):
|
||||
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
|
||||
# This test fails non-deterministically with LayeredTwistedIOLoop.
|
||||
# (the read_until('\n') returns '\n' instead of 'hello\n')
|
||||
# This probably indicates a problem with either TornadoReactor
|
||||
# or TwistedIOLoop, but I haven't been able to track it down
|
||||
# and for now this is just causing spurious travis-ci failures.
|
||||
raise unittest.SkipTest("Subprocess tests not compatible with "
|
||||
"LayeredTwistedIOLoop")
|
||||
subproc = Subprocess([sys.executable, '-u', '-i'],
|
||||
stdin=Subprocess.STREAM,
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import ThreadedResolver
|
||||
from tornado.util import u
|
||||
|
||||
# When this module is imported, it runs getaddrinfo on a thread. Since
|
||||
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
|
||||
# but blocks on the import lock. Verify that ThreadedResolver avoids
|
||||
# this deadlock.
|
||||
|
||||
resolver = ThreadedResolver()
|
||||
IOLoop.current().run_sync(lambda: resolver.resolve(u('localhost'), 80))
|
||||
|
|
@ -13,6 +13,11 @@ from tornado.netutil import Resolver
|
|||
from tornado.options import define, options, add_parse_callback
|
||||
from tornado.test.util import unittest
|
||||
|
||||
try:
|
||||
reduce # py2
|
||||
except NameError:
|
||||
from functools import reduce # py3
|
||||
|
||||
TEST_MODULES = [
|
||||
'tornado.httputil.doctests',
|
||||
'tornado.iostream.doctests',
|
||||
|
|
@ -35,6 +40,7 @@ TEST_MODULES = [
|
|||
'tornado.test.process_test',
|
||||
'tornado.test.simple_httpclient_test',
|
||||
'tornado.test.stack_context_test',
|
||||
'tornado.test.tcpclient_test',
|
||||
'tornado.test.template_test',
|
||||
'tornado.test.testing_test',
|
||||
'tornado.test.twisted_test',
|
||||
|
|
@ -60,7 +66,8 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
|
|||
self.stream.write("\n")
|
||||
return result
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def main():
|
||||
# The -W command-line option does not work in a virtualenv with
|
||||
# python 3 (as of virtualenv 1.7), so configure warnings
|
||||
# programmatically instead.
|
||||
|
|
@ -77,6 +84,9 @@ if __name__ == '__main__':
|
|||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("error", category=DeprecationWarning,
|
||||
module=r"tornado\..*")
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
warnings.filterwarnings("error", category=PendingDeprecationWarning,
|
||||
module=r"tornado\..*")
|
||||
# The unittest module is aggressive about deprecating redundant methods,
|
||||
# leaving some without non-deprecated spellings that work on both
|
||||
# 2.7 and 3.2
|
||||
|
|
@ -86,7 +96,8 @@ if __name__ == '__main__':
|
|||
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
|
||||
|
||||
define('httpclient', type=str, default=None,
|
||||
callback=AsyncHTTPClient.configure)
|
||||
callback=lambda s: AsyncHTTPClient.configure(
|
||||
s, defaults=dict(allow_ipv6=False)))
|
||||
define('ioloop', type=str, default=None)
|
||||
define('ioloop_time_monotonic', default=False)
|
||||
define('resolver', type=str, default=None,
|
||||
|
|
@ -121,3 +132,6 @@ if __name__ == '__main__':
|
|||
kwargs['warnings'] = False
|
||||
kwargs['testRunner'] = TornadoTextTestRunner
|
||||
tornado.testing.main(**kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -10,16 +10,18 @@ import re
|
|||
import socket
|
||||
import sys
|
||||
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import gen_log
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import Resolver, bind_sockets
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
|
||||
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest, skipOnTravis
|
||||
from tornado.web import RequestHandler, Application, asynchronous, url
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6
|
||||
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
|
||||
|
||||
|
||||
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
|
|
@ -69,7 +71,8 @@ class OptionsHandler(RequestHandler):
|
|||
class NoContentHandler(RequestHandler):
|
||||
def get(self):
|
||||
if self.get_argument("error", None):
|
||||
self.set_header("Content-Length", "7")
|
||||
self.set_header("Content-Length", "5")
|
||||
self.write("hello")
|
||||
self.set_status(204)
|
||||
|
||||
|
||||
|
|
@ -93,6 +96,30 @@ class HostEchoHandler(RequestHandler):
|
|||
self.write(self.request.headers["Host"])
|
||||
|
||||
|
||||
class NoContentLengthHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
# Emulate the old HTTP/1.0 behavior of returning a body with no
|
||||
# content-length. Tornado handles content-length at the framework
|
||||
# level so we have to go around it.
|
||||
stream = self.request.connection.stream
|
||||
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
|
||||
b"hello")
|
||||
stream.close()
|
||||
|
||||
|
||||
class EchoPostHandler(RequestHandler):
|
||||
def post(self):
|
||||
self.write(self.request.body)
|
||||
|
||||
|
||||
@stream_request_body
|
||||
class RespondInPrepareHandler(RequestHandler):
|
||||
def prepare(self):
|
||||
self.set_status(403)
|
||||
self.finish("forbidden")
|
||||
|
||||
|
||||
class SimpleHTTPClientTestMixin(object):
|
||||
def get_app(self):
|
||||
# callable objects to finish pending /trigger requests
|
||||
|
|
@ -111,6 +138,9 @@ class SimpleHTTPClientTestMixin(object):
|
|||
url("/see_other_post", SeeOtherPostHandler),
|
||||
url("/see_other_get", SeeOtherGetHandler),
|
||||
url("/host_echo", HostEchoHandler),
|
||||
url("/no_content_length", NoContentLengthHandler),
|
||||
url("/echo_post", EchoPostHandler),
|
||||
url("/respond_in_prepare", RespondInPrepareHandler),
|
||||
], gzip=True)
|
||||
|
||||
def test_singleton(self):
|
||||
|
|
@ -122,9 +152,9 @@ class SimpleHTTPClientTestMixin(object):
|
|||
SimpleAsyncHTTPClient(self.io_loop,
|
||||
force_instance=True))
|
||||
# different IOLoops use different objects
|
||||
io_loop2 = IOLoop()
|
||||
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
|
||||
SimpleAsyncHTTPClient(io_loop2))
|
||||
with closing(IOLoop()) as io_loop2:
|
||||
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
|
||||
SimpleAsyncHTTPClient(io_loop2))
|
||||
|
||||
def test_connection_limit(self):
|
||||
with closing(self.create_client(max_clients=2)) as client:
|
||||
|
|
@ -162,7 +192,7 @@ class SimpleHTTPClientTestMixin(object):
|
|||
response.rethrow()
|
||||
|
||||
def test_default_certificates_exist(self):
|
||||
open(_DEFAULT_CA_CERTS).close()
|
||||
open(_default_ca_certs()).close()
|
||||
|
||||
def test_gzip(self):
|
||||
# All the tests in this file should be using gzip, but this test
|
||||
|
|
@ -212,28 +242,30 @@ class SimpleHTTPClientTestMixin(object):
|
|||
# trigger the hanging request to let it clean up after itself
|
||||
self.triggers.popleft()()
|
||||
|
||||
@unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
|
||||
@skipIfNoIPv6
|
||||
def test_ipv6(self):
|
||||
try:
|
||||
self.http_server.listen(self.get_http_port(), address='::1')
|
||||
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
|
||||
port = sock.getsockname()[1]
|
||||
self.http_server.add_socket(sock)
|
||||
except socket.gaierror as e:
|
||||
if e.args[0] == socket.EAI_ADDRFAMILY:
|
||||
# python supports ipv6, but it's not configured on the network
|
||||
# interface, so skip this test.
|
||||
return
|
||||
raise
|
||||
url = self.get_url("/hello").replace("localhost", "[::1]")
|
||||
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
|
||||
|
||||
# ipv6 is currently disabled by default and must be explicitly requested
|
||||
self.http_client.fetch(url, self.stop)
|
||||
# ipv6 is currently enabled by default but can be disabled
|
||||
self.http_client.fetch(url, self.stop, allow_ipv6=False)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
self.http_client.fetch(url, self.stop, allow_ipv6=True)
|
||||
self.http_client.fetch(url, self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
def test_multiple_content_length_accepted(self):
|
||||
def xtest_multiple_content_length_accepted(self):
|
||||
response = self.fetch("/content_length?value=2,2")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
response = self.fetch("/content_length?value=2,%202,2")
|
||||
|
|
@ -265,7 +297,8 @@ class SimpleHTTPClientTestMixin(object):
|
|||
self.assertEqual(response.headers["Content-length"], "0")
|
||||
|
||||
# 204 status with non-zero content length is malformed
|
||||
response = self.fetch("/no_content?error=1")
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = self.fetch("/no_content?error=1")
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
def test_host_header(self):
|
||||
|
|
@ -288,14 +321,86 @@ class SimpleHTTPClientTestMixin(object):
|
|||
|
||||
if sys.platform != 'cygwin':
|
||||
# cygwin returns EPERM instead of ECONNREFUSED here
|
||||
self.assertTrue(str(errno.ECONNREFUSED) in str(response.error),
|
||||
response.error)
|
||||
contains_errno = str(errno.ECONNREFUSED) in str(response.error)
|
||||
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
|
||||
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
|
||||
self.assertTrue(contains_errno, response.error)
|
||||
# This is usually "Connection refused".
|
||||
# On windows, strerror is broken and returns "Unknown error".
|
||||
expected_message = os.strerror(errno.ECONNREFUSED)
|
||||
self.assertTrue(expected_message in str(response.error),
|
||||
response.error)
|
||||
|
||||
def test_queue_timeout(self):
|
||||
with closing(self.create_client(max_clients=1)) as client:
|
||||
client.fetch(self.get_url('/trigger'), self.stop,
|
||||
request_timeout=10)
|
||||
# Wait for the trigger request to block, not complete.
|
||||
self.wait()
|
||||
client.fetch(self.get_url('/hello'), self.stop,
|
||||
connect_timeout=0.1)
|
||||
response = self.wait()
|
||||
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertTrue(response.request_time < 1, response.request_time)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
self.triggers.popleft()()
|
||||
self.wait()
|
||||
|
||||
def test_no_content_length(self):
|
||||
response = self.fetch("/no_content_length")
|
||||
self.assertEquals(b"hello", response.body)
|
||||
|
||||
def sync_body_producer(self, write):
|
||||
write(b'1234')
|
||||
write(b'5678')
|
||||
|
||||
@gen.coroutine
|
||||
def async_body_producer(self, write):
|
||||
yield write(b'1234')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield write(b'5678')
|
||||
|
||||
def test_sync_body_producer_chunked(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.sync_body_producer)
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_sync_body_producer_content_length(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.sync_body_producer,
|
||||
headers={'Content-Length': '8'})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_async_body_producer_chunked(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.async_body_producer)
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_async_body_producer_content_length(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.async_body_producer,
|
||||
headers={'Content-Length': '8'})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_100_continue(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body=b"1234",
|
||||
expect_100_continue=True)
|
||||
self.assertEqual(response.body, b"1234")
|
||||
|
||||
def test_100_continue_early_response(self):
|
||||
def body_producer(write):
|
||||
raise Exception("should not be called")
|
||||
response = self.fetch("/respond_in_prepare", method="POST",
|
||||
body_producer=body_producer,
|
||||
expect_100_continue=True)
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
|
||||
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -396,3 +501,52 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
|
|||
response = self.wait()
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
|
||||
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
# Dummy Resolver subclass that never invokes its callback.
|
||||
class BadResolver(Resolver):
|
||||
def resolve(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
super(ResolveTimeoutTestCase, self).setUp()
|
||||
self.http_client = SimpleAsyncHTTPClient(
|
||||
self.io_loop,
|
||||
resolver=BadResolver())
|
||||
|
||||
def get_app(self):
|
||||
return Application([url("/hello", HelloWorldHandler), ])
|
||||
|
||||
def test_resolve_timeout(self):
|
||||
response = self.fetch('/hello', connect_timeout=0.1)
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class SmallHeaders(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header("X-Filler", "a" * 100)
|
||||
self.write("ok")
|
||||
|
||||
class LargeHeaders(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header("X-Filler", "a" * 1000)
|
||||
self.write("ok")
|
||||
|
||||
return Application([('/small', SmallHeaders),
|
||||
('/large', LargeHeaders)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
|
||||
|
||||
def test_small_headers(self):
|
||||
response = self.fetch('/small')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'ok')
|
||||
|
||||
def test_large_headers(self):
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
response = self.fetch('/large')
|
||||
self.assertEqual(response.code, 599)
|
||||
|
|
|
|||
|
|
@ -35,11 +35,11 @@ class TestRequestHandler(RequestHandler):
|
|||
logging.debug('in part3()')
|
||||
raise Exception('test exception')
|
||||
|
||||
def get_error_html(self, status_code, **kwargs):
|
||||
if 'exception' in kwargs and str(kwargs['exception']) == 'test exception':
|
||||
return 'got expected exception'
|
||||
def write_error(self, status_code, **kwargs):
|
||||
if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
|
||||
self.write('got expected exception')
|
||||
else:
|
||||
return 'unexpected failure'
|
||||
self.write('unexpected failure')
|
||||
|
||||
|
||||
class HTTPStackContextTest(AsyncHTTPTestCase):
|
||||
|
|
@ -219,16 +219,22 @@ class StackContextTest(AsyncTestCase):
|
|||
def test_yield_in_with(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
self.callback = yield gen.Callback('a')
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
# This yield is a problem: the generator will be suspended
|
||||
# and the StackContext's __exit__ is not called yet, so
|
||||
# the context will be left on _state.contexts for anything
|
||||
# that runs before the yield resolves.
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.Wait('a')
|
||||
|
||||
with self.assertRaises(StackContextInconsistentError):
|
||||
f()
|
||||
self.wait()
|
||||
# Cleanup: to avoid GC warnings (which for some reason only seem
|
||||
# to show up on py33-asyncio), invoke the callback (which will do
|
||||
# nothing since the gen.Runner is already finished) and delete it.
|
||||
self.callback()
|
||||
del self.callback
|
||||
|
||||
@gen_test
|
||||
def test_yield_outside_with(self):
|
||||
|
|
@ -256,12 +262,13 @@ class StackContextTest(AsyncTestCase):
|
|||
self.io_loop.add_callback(cb)
|
||||
yield gen.Wait('k1')
|
||||
|
||||
@gen_test
|
||||
def test_run_with_stack_context(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
self.assertEqual(self.active_contexts, ['c1'])
|
||||
yield run_with_stack_context(
|
||||
StackContext(functools.partial(self.context, 'c1')),
|
||||
StackContext(functools.partial(self.context, 'c2')),
|
||||
f2)
|
||||
self.assertEqual(self.active_contexts, ['c1'])
|
||||
|
||||
|
|
@ -272,7 +279,7 @@ class StackContextTest(AsyncTestCase):
|
|||
self.assertEqual(self.active_contexts, ['c1', 'c2'])
|
||||
|
||||
self.assertEqual(self.active_contexts, [])
|
||||
run_with_stack_context(
|
||||
yield run_with_stack_context(
|
||||
StackContext(functools.partial(self.context, 'c1')),
|
||||
f1)
|
||||
self.assertEqual(self.active_contexts, [])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,278 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2014 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
from contextlib import closing
|
||||
import os
|
||||
import socket
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado.netutil import bind_sockets, Resolver
|
||||
from tornado.tcpclient import TCPClient, _Connector
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
|
||||
from tornado.test.util import skipIfNoIPv6, unittest
|
||||
|
||||
# Fake address families for testing. Used in place of AF_INET
|
||||
# and AF_INET6 because some installations do not have AF_INET6.
|
||||
AF1, AF2 = 1, 2
|
||||
|
||||
|
||||
class TestTCPServer(TCPServer):
|
||||
def __init__(self, family):
|
||||
super(TestTCPServer, self).__init__()
|
||||
self.streams = []
|
||||
sockets = bind_sockets(None, 'localhost', family)
|
||||
self.add_sockets(sockets)
|
||||
self.port = sockets[0].getsockname()[1]
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
self.streams.append(stream)
|
||||
|
||||
def stop(self):
|
||||
super(TestTCPServer, self).stop()
|
||||
for stream in self.streams:
|
||||
stream.close()
|
||||
|
||||
|
||||
class TCPClientTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(TCPClientTest, self).setUp()
|
||||
self.server = None
|
||||
self.client = TCPClient()
|
||||
|
||||
def start_server(self, family):
|
||||
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
|
||||
self.skipTest("dual-stack servers often have port conflicts on travis")
|
||||
self.server = TestTCPServer(family)
|
||||
return self.server.port
|
||||
|
||||
def stop_server(self):
|
||||
if self.server is not None:
|
||||
self.server.stop()
|
||||
self.server = None
|
||||
|
||||
def tearDown(self):
|
||||
self.client.close()
|
||||
self.stop_server()
|
||||
super(TCPClientTest, self).tearDown()
|
||||
|
||||
def skipIfLocalhostV4(self):
|
||||
Resolver().resolve('localhost', 0, callback=self.stop)
|
||||
addrinfo = self.wait()
|
||||
families = set(addr[0] for addr in addrinfo)
|
||||
if socket.AF_INET6 not in families:
|
||||
self.skipTest("localhost does not resolve to ipv6")
|
||||
|
||||
@gen_test
|
||||
def do_test_connect(self, family, host):
|
||||
port = self.start_server(family)
|
||||
stream = yield self.client.connect(host, port)
|
||||
with closing(stream):
|
||||
stream.write(b"hello")
|
||||
data = yield self.server.streams[0].read_bytes(5)
|
||||
self.assertEqual(data, b"hello")
|
||||
|
||||
def test_connect_ipv4_ipv4(self):
|
||||
self.do_test_connect(socket.AF_INET, '127.0.0.1')
|
||||
|
||||
def test_connect_ipv4_dual(self):
|
||||
self.do_test_connect(socket.AF_INET, 'localhost')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_ipv6_ipv6(self):
|
||||
self.skipIfLocalhostV4()
|
||||
self.do_test_connect(socket.AF_INET6, '::1')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_ipv6_dual(self):
|
||||
self.skipIfLocalhostV4()
|
||||
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
|
||||
self.skipTest('TwistedResolver does not support multiple addresses')
|
||||
self.do_test_connect(socket.AF_INET6, 'localhost')
|
||||
|
||||
def test_connect_unspec_ipv4(self):
|
||||
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_unspec_ipv6(self):
|
||||
self.skipIfLocalhostV4()
|
||||
self.do_test_connect(socket.AF_UNSPEC, '::1')
|
||||
|
||||
def test_connect_unspec_dual(self):
|
||||
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
|
||||
|
||||
@gen_test
|
||||
def test_refused_ipv4(self):
|
||||
sock, port = bind_unused_port()
|
||||
sock.close()
|
||||
with self.assertRaises(IOError):
|
||||
yield self.client.connect('127.0.0.1', port)
|
||||
|
||||
|
||||
class TestConnectorSplit(unittest.TestCase):
|
||||
def test_one_family(self):
|
||||
# These addresses aren't in the right format, but split doesn't care.
|
||||
primary, secondary = _Connector.split(
|
||||
[(AF1, 'a'),
|
||||
(AF1, 'b')])
|
||||
self.assertEqual(primary, [(AF1, 'a'),
|
||||
(AF1, 'b')])
|
||||
self.assertEqual(secondary, [])
|
||||
|
||||
def test_mixed(self):
|
||||
primary, secondary = _Connector.split(
|
||||
[(AF1, 'a'),
|
||||
(AF2, 'b'),
|
||||
(AF1, 'c'),
|
||||
(AF2, 'd')])
|
||||
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
|
||||
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
|
||||
|
||||
|
||||
class ConnectorTest(AsyncTestCase):
|
||||
class FakeStream(object):
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def setUp(self):
|
||||
super(ConnectorTest, self).setUp()
|
||||
self.connect_futures = {}
|
||||
self.streams = {}
|
||||
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
|
||||
(AF2, 'c'), (AF2, 'd')]
|
||||
|
||||
def tearDown(self):
|
||||
# Unless explicitly checked (and popped) in the test, we shouldn't
|
||||
# be closing any streams
|
||||
for stream in self.streams.values():
|
||||
self.assertFalse(stream.closed)
|
||||
super(ConnectorTest, self).tearDown()
|
||||
|
||||
def create_stream(self, af, addr):
|
||||
future = Future()
|
||||
self.connect_futures[(af, addr)] = future
|
||||
return future
|
||||
|
||||
def assert_pending(self, *keys):
|
||||
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
|
||||
|
||||
def resolve_connect(self, af, addr, success):
|
||||
future = self.connect_futures.pop((af, addr))
|
||||
if success:
|
||||
self.streams[addr] = ConnectorTest.FakeStream()
|
||||
future.set_result(self.streams[addr])
|
||||
else:
|
||||
future.set_exception(IOError())
|
||||
|
||||
def start_connect(self, addrinfo):
|
||||
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
|
||||
# Give it a huge timeout; we'll trigger timeouts manually.
|
||||
future = conn.start(3600)
|
||||
return conn, future
|
||||
|
||||
def test_immediate_success(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assertEqual(list(self.connect_futures.keys()),
|
||||
[(AF1, 'a')])
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_immediate_failure(self):
|
||||
# Fail with just one address.
|
||||
conn, future = self.start_connect([(AF1, 'a')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_second_try(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', True)
|
||||
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
|
||||
|
||||
def test_one_family_second_try_failure(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_second_try_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
# trigger the timeout while the first lookup is pending;
|
||||
# nothing happens.
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', True)
|
||||
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
|
||||
|
||||
def test_two_families_immediate_failure(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
|
||||
|
||||
def test_two_families_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
|
||||
# resolving 'a' after the connection has completed doesn't start 'b'
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending()
|
||||
|
||||
def test_success_after_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
# resolving 'c' after completion closes the connection.
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
|
||||
def test_all_fail(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF2, 'c', False)
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'd'))
|
||||
self.resolve_connect(AF2, 'd', False)
|
||||
# one queue is now empty
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.assertFalse(future.done())
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
|
@ -182,6 +182,7 @@ three
|
|||
"""})
|
||||
try:
|
||||
loader.load("test.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# test.html:2" in traceback.format_exc())
|
||||
|
||||
|
|
@ -192,6 +193,7 @@ three{%end%}
|
|||
"""})
|
||||
try:
|
||||
loader.load("test.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# test.html:2" in traceback.format_exc())
|
||||
|
||||
|
|
@ -202,6 +204,7 @@ three{%end%}
|
|||
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
|
||||
try:
|
||||
loader.load("base.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
exc_stack = traceback.format_exc()
|
||||
self.assertTrue('# base.html:1' in exc_stack)
|
||||
|
|
@ -214,6 +217,7 @@ three{%end%}
|
|||
})
|
||||
try:
|
||||
loader.load("base.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# sub.html:1 (via base.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
|
@ -225,6 +229,7 @@ three{%end%}
|
|||
})
|
||||
try:
|
||||
loader.load("sub.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
exc_stack = traceback.format_exc()
|
||||
self.assertTrue("# base.html:1" in exc_stack)
|
||||
|
|
@ -240,6 +245,7 @@ three{%end%}
|
|||
"""})
|
||||
try:
|
||||
loader.load("sub.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# sub.html:4 (via base.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
|
@ -252,6 +258,7 @@ three{%end%}
|
|||
})
|
||||
try:
|
||||
loader.load("a.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
|
@ -380,6 +387,20 @@ raw: {% raw name %}""",
|
|||
self.assertEqual(render("foo.py", ["not a string"]),
|
||||
b"""s = "['not a string']"\n""")
|
||||
|
||||
def test_minimize_whitespace(self):
|
||||
# Whitespace including newlines is allowed within template tags
|
||||
# and directives, and this is one way to avoid long lines while
|
||||
# keeping extra whitespace out of the rendered output.
|
||||
loader = DictLoader({'foo.txt': """\
|
||||
{% for i in items
|
||||
%}{% if i > 0 %}, {% end %}{#
|
||||
#}{{i
|
||||
}}{% end
|
||||
%}""",
|
||||
})
|
||||
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
|
||||
b"0, 1, 2, 3, 4")
|
||||
|
||||
|
||||
class TemplateLoaderTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ from tornado.test.util import unittest
|
|||
|
||||
import contextlib
|
||||
import os
|
||||
import traceback
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_environ(name, value):
|
||||
old_value = os.environ.get('name')
|
||||
old_value = os.environ.get(name)
|
||||
os.environ[name] = value
|
||||
|
||||
try:
|
||||
|
|
@ -62,6 +63,39 @@ class AsyncTestCaseTest(AsyncTestCase):
|
|||
self.wait(timeout=0.15)
|
||||
|
||||
|
||||
class AsyncTestCaseWrapperTest(unittest.TestCase):
|
||||
def test_undecorated_generator(self):
|
||||
class Test(AsyncTestCase):
|
||||
def test_gen(self):
|
||||
yield
|
||||
test = Test('test_gen')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 1)
|
||||
self.assertIn("should be decorated", result.errors[0][1])
|
||||
|
||||
def test_undecorated_generator_with_skip(self):
|
||||
class Test(AsyncTestCase):
|
||||
@unittest.skip("don't run this")
|
||||
def test_gen(self):
|
||||
yield
|
||||
test = Test('test_gen')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 0)
|
||||
self.assertEqual(len(result.skipped), 1)
|
||||
|
||||
def test_other_return(self):
|
||||
class Test(AsyncTestCase):
|
||||
def test_other_return(self):
|
||||
return 42
|
||||
test = Test('test_other_return')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 1)
|
||||
self.assertIn("Return value from test method ignored", result.errors[0][1])
|
||||
|
||||
|
||||
class SetUpTearDownTest(unittest.TestCase):
|
||||
def test_set_up_tear_down(self):
|
||||
"""
|
||||
|
|
@ -115,8 +149,17 @@ class GenTest(AsyncTestCase):
|
|||
def test(self):
|
||||
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
|
||||
|
||||
with self.assertRaises(ioloop.TimeoutError):
|
||||
# This can't use assertRaises because we need to inspect the
|
||||
# exc_info triple (and not just the exception object)
|
||||
try:
|
||||
test(self)
|
||||
self.fail("did not get expected exception")
|
||||
except ioloop.TimeoutError:
|
||||
# The stack trace should blame the add_timeout line, not just
|
||||
# unrelated IOLoop/testing internals.
|
||||
self.assertIn(
|
||||
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
|
||||
traceback.format_exc())
|
||||
|
||||
self.finished = True
|
||||
|
||||
|
|
@ -155,5 +198,23 @@ class GenTest(AsyncTestCase):
|
|||
|
||||
self.finished = True
|
||||
|
||||
def test_with_method_args(self):
|
||||
@gen_test
|
||||
def test_with_args(self, *args):
|
||||
self.assertEqual(args, ('test',))
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
|
||||
test_with_args(self, 'test')
|
||||
self.finished = True
|
||||
|
||||
def test_with_method_kwargs(self):
|
||||
@gen_test
|
||||
def test_with_kwargs(self, **kwargs):
|
||||
self.assertDictEqual(kwargs, {'test': 'test'})
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
|
||||
test_with_kwargs(self, test='test')
|
||||
self.finished = True
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -470,14 +470,17 @@ if have_twisted:
|
|||
'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
|
||||
'twisted.internet.test.test_core.SystemEventTestsBuilder': [
|
||||
'test_iterate', # deliberately not supported
|
||||
'test_runAfterCrash', # fails because TwistedIOLoop uses the global reactor
|
||||
] if issubclass(IOLoop.configured_class(), TwistedIOLoop) else [
|
||||
'test_iterate', # deliberately not supported
|
||||
# Fails on TwistedIOLoop and AsyncIOLoop.
|
||||
'test_runAfterCrash',
|
||||
],
|
||||
'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
|
||||
"test_lostFileDescriptor", # incompatible with epoll and kqueue
|
||||
],
|
||||
'twisted.internet.test.test_process.ProcessTestsBuilder': [
|
||||
# Only work as root. Twisted's "skip" functionality works
|
||||
# with py27+, but not unittest2 on py26.
|
||||
'test_changeGID',
|
||||
'test_changeUID',
|
||||
],
|
||||
# Process tests appear to work on OSX 10.7, but not 10.6
|
||||
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
|
||||
|
|
|
|||
|
|
@ -1,14 +1,18 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
# Encapsulate the choice of unittest or unittest2 here.
|
||||
# To be used as 'from tornado.test.util import unittest'.
|
||||
if sys.version_info >= (2, 7):
|
||||
import unittest
|
||||
else:
|
||||
if sys.version_info < (2, 7):
|
||||
# In py26, we must always use unittest2.
|
||||
import unittest2 as unittest
|
||||
else:
|
||||
# Otherwise, use whichever version of unittest was imported in
|
||||
# tornado.testing.
|
||||
from tornado.testing import unittest
|
||||
|
||||
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
|
||||
"non-unix platform")
|
||||
|
|
@ -17,3 +21,10 @@ skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
|
|||
# timing-related tests unreliable.
|
||||
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
|
||||
'timing tests unreliable on travis')
|
||||
|
||||
# Set the environment variable NO_NETWORK=1 to disable any tests that
|
||||
# depend on an external network.
|
||||
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
|
||||
'network access disabled')
|
||||
|
||||
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
|
||||
|
|
|
|||
|
|
@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase):
|
|||
self.replacer = ArgReplacer(function, 'callback')
|
||||
|
||||
def test_omitted(self):
|
||||
self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
|
||||
args = (1, 2)
|
||||
kwargs = dict()
|
||||
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
(None, (1, 2), dict(callback='new')))
|
||||
|
||||
def test_position(self):
|
||||
self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
|
||||
args = (1, 2, 'old', 3)
|
||||
kwargs = dict()
|
||||
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
('old', [1, 2, 'new', 3], dict()))
|
||||
|
||||
def test_keyword(self):
|
||||
self.assertEqual(self.replacer.replace('new', (1,),
|
||||
dict(y=2, callback='old', z=3)),
|
||||
args = (1,)
|
||||
kwargs = dict(y=2, callback='old', z=3)
|
||||
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
('old', (1,), dict(y=2, callback='new', z=3)))
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,21 +1,66 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import traceback
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.log import gen_log
|
||||
from tornado.httpclient import HTTPError, HTTPRequest
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.web import Application, RequestHandler
|
||||
from tornado.util import u
|
||||
|
||||
try:
|
||||
import tornado.websocket
|
||||
from tornado.util import _websocket_mask_python
|
||||
except ImportError:
|
||||
# The unittest module presents misleading errors on ImportError
|
||||
# (it acts as if websocket_test could not be found, hiding the underlying
|
||||
# error). If we get an ImportError here (which could happen due to
|
||||
# TORNADO_EXTENSION=1), print some extra information before failing.
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
|
||||
|
||||
try:
|
||||
from tornado import speedups
|
||||
except ImportError:
|
||||
speedups = None
|
||||
|
||||
class EchoHandler(WebSocketHandler):
|
||||
|
||||
class TestWebSocketHandler(WebSocketHandler):
|
||||
"""Base class for testing handlers that exposes the on_close event.
|
||||
|
||||
This allows for deterministic cleanup of the associated socket.
|
||||
"""
|
||||
def initialize(self, close_future):
|
||||
self.close_future = close_future
|
||||
|
||||
def on_close(self):
|
||||
self.close_future.set_result((self.close_code, self.close_reason))
|
||||
|
||||
|
||||
class EchoHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
self.write_message(message, isinstance(message, bytes))
|
||||
|
||||
def on_close(self):
|
||||
self.close_future.set_result(None)
|
||||
|
||||
class ErrorInOnMessageHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
1/0
|
||||
|
||||
|
||||
class HeaderHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
try:
|
||||
# In a websocket context, many RequestHandler methods
|
||||
# raise RuntimeErrors.
|
||||
self.set_status(503)
|
||||
raise Exception("did not get expected exception")
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.write_message(self.request.headers.get('X-Test', ''))
|
||||
|
||||
|
||||
class NonWebSocketHandler(RequestHandler):
|
||||
|
|
@ -23,14 +68,29 @@ class NonWebSocketHandler(RequestHandler):
|
|||
self.write('ok')
|
||||
|
||||
|
||||
class CloseReasonHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
self.close(1001, "goodbye")
|
||||
|
||||
|
||||
class WebSocketTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/echo', EchoHandler, dict(close_future=self.close_future)),
|
||||
('/non_ws', NonWebSocketHandler),
|
||||
('/header', HeaderHandler, dict(close_future=self.close_future)),
|
||||
('/close_reason', CloseReasonHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/error_in_on_message', ErrorInOnMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
])
|
||||
|
||||
def test_http_request(self):
|
||||
# WS server, HTTP client.
|
||||
response = self.fetch('/echo')
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_gen(self):
|
||||
ws = yield websocket_connect(
|
||||
|
|
@ -39,6 +99,8 @@ class WebSocketTest(AsyncHTTPTestCase):
|
|||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
def test_websocket_callbacks(self):
|
||||
websocket_connect(
|
||||
|
|
@ -49,6 +111,40 @@ class WebSocketTest(AsyncHTTPTestCase):
|
|||
ws.read_message(self.stop)
|
||||
response = self.wait().result()
|
||||
self.assertEqual(response, 'hello')
|
||||
self.close_future.add_done_callback(lambda f: self.stop())
|
||||
ws.close()
|
||||
self.wait()
|
||||
|
||||
@gen_test
|
||||
def test_binary_message(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port())
|
||||
ws.write_message(b'hello \xe9', binary=True)
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, b'hello \xe9')
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_unicode_message(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port())
|
||||
ws.write_message(u('hello \u00e9'))
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, u('hello \u00e9'))
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_error_in_on_message(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/error_in_on_message' % self.get_http_port())
|
||||
ws.write_message('hello')
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = yield ws.read_message()
|
||||
self.assertIs(response, None)
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_websocket_http_fail(self):
|
||||
|
|
@ -69,13 +165,12 @@ class WebSocketTest(AsyncHTTPTestCase):
|
|||
def test_websocket_network_fail(self):
|
||||
sock, port = bind_unused_port()
|
||||
sock.close()
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
with self.assertRaises(IOError):
|
||||
with ExpectLog(gen_log, ".*"):
|
||||
yield websocket_connect(
|
||||
'ws://localhost:%d/' % port,
|
||||
io_loop=self.io_loop,
|
||||
connect_timeout=0.01)
|
||||
self.assertEqual(cm.exception.code, 599)
|
||||
connect_timeout=3600)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_close_buffered_data(self):
|
||||
|
|
@ -85,3 +180,134 @@ class WebSocketTest(AsyncHTTPTestCase):
|
|||
ws.write_message('world')
|
||||
ws.stream.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_websocket_headers(self):
|
||||
# Ensure that arbitrary headers can be passed through websocket_connect.
|
||||
ws = yield websocket_connect(
|
||||
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
|
||||
headers={'X-Test': 'hello'}))
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_server_close_reason(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/close_reason' % self.get_http_port())
|
||||
msg = yield ws.read_message()
|
||||
# A message of None means the other side closed the connection.
|
||||
self.assertIs(msg, None)
|
||||
self.assertEqual(ws.close_code, 1001)
|
||||
self.assertEqual(ws.close_reason, "goodbye")
|
||||
|
||||
@gen_test
|
||||
def test_client_close_reason(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port())
|
||||
ws.close(1001, 'goodbye')
|
||||
code, reason = yield self.close_future
|
||||
self.assertEqual(code, 1001)
|
||||
self.assertEqual(reason, 'goodbye')
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_valid_no_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'http://localhost:%d' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_valid_with_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'http://localhost:%d/something' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid_partial_url(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'localhost:%d' % port}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
# Host is localhost, which should not be accessible from some other
|
||||
# domain
|
||||
headers = {'Origin': 'http://somewhereelse.com'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid_subdomains(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
# Subdomains should be disallowed by default. If we could pass a
|
||||
# resolver to websocket_connect we could test sibling domains as well.
|
||||
headers = {'Origin': 'http://subtenant.localhost'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
|
||||
class MaskFunctionMixin(object):
|
||||
# Subclasses should define self.mask(mask, data)
|
||||
def test_mask(self):
|
||||
self.assertEqual(self.mask(b'abcd', b''), b'')
|
||||
self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
|
||||
self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
|
||||
self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
|
||||
# Include test cases with \x00 bytes (to ensure that the C
|
||||
# extension isn't depending on null-terminated strings) and
|
||||
# bytes with the high bit set (to smoke out signedness issues).
|
||||
self.assertEqual(self.mask(b'\x00\x01\x02\x03',
|
||||
b'\xff\xfb\xfd\xfc\xfe\xfa'),
|
||||
b'\xff\xfa\xff\xff\xfe\xfb')
|
||||
self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
|
||||
b'\x00\x01\x02\x03\x04\x05'),
|
||||
b'\xff\xfa\xff\xff\xfb\xfe')
|
||||
|
||||
|
||||
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
||||
def mask(self, mask, data):
|
||||
return _websocket_mask_python(mask, data)
|
||||
|
||||
|
||||
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
|
||||
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
||||
def mask(self, mask, data):
|
||||
return speedups.websocket_mask(mask, data)
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from tornado.escape import json_decode
|
|||
from tornado.test.httpserver_test import TypeCheckHandler
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.util import u
|
||||
from tornado.web import RequestHandler
|
||||
from tornado.wsgi import WSGIApplication, WSGIContainer
|
||||
from tornado.web import RequestHandler, Application
|
||||
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
|
||||
|
||||
|
||||
class WSGIContainerTest(AsyncHTTPTestCase):
|
||||
|
|
@ -74,14 +74,27 @@ class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
|
|||
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
|
||||
|
||||
|
||||
def wrap_web_tests():
|
||||
def wrap_web_tests_application():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
class WSGIWrappedTest(cls):
|
||||
class WSGIApplicationWrappedTest(cls):
|
||||
def get_app(self):
|
||||
self.app = WSGIApplication(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(self.app))
|
||||
result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
|
||||
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
|
||||
return result
|
||||
globals().update(wrap_web_tests())
|
||||
globals().update(wrap_web_tests_application())
|
||||
|
||||
|
||||
def wrap_web_tests_adapter():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
class WSGIAdapterWrappedTest(cls):
|
||||
def get_app(self):
|
||||
self.app = Application(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(WSGIAdapter(self.app)))
|
||||
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
|
||||
return result
|
||||
globals().update(wrap_web_tests_adapter())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue