update to tornado 4.0 and requests 2.3.0

This commit is contained in:
j 2014-08-12 10:44:01 +02:00
commit f187000dc9
239 changed files with 19071 additions and 20369 deletions

View file

@ -0,0 +1,14 @@
"""Shim to allow python -m tornado.test.
This only works in python 2.7+.
"""
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all
main()

View file

@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
self.finish(user)
def _oauth_get_user(self, access_token, callback):
if self.get_argument('fail_in_get_user', None):
raise Exception("failing in get_user")
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
callback(dict(email='foo@example.com'))
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
def get(self):
if self.get_argument('oauth_token', None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
yield self.get_authenticated_user()
except Exception as e:
self.set_status(503)
self.write("got exception: %s" % e)
else:
yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase):
dict(version='1.0')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/login_coroutine',
OAuth1ClientLoginCoroutineHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase):
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)

View file

@ -30,6 +30,12 @@ from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
try:
from concurrent import futures
except ImportError:
futures = None
class ReturnFutureTest(AsyncTestCase):
@return_future
def sync_future(self, callback):

View file

@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function, with_statement
from hashlib import md5
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
@ -21,7 +22,8 @@ if pycurl is not None:
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(io_loop=self.io_loop)
client = CurlAsyncHTTPClient(io_loop=self.io_loop,
defaults=dict(allow_ipv6=False))
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
@ -51,10 +53,10 @@ class DigestAuthHandler(RequestHandler):
assert param_dict['nonce'] == nonce
assert param_dict['username'] == username
assert param_dict['uri'] == self.request.path
h1 = md5('%s:%s:%s' % (username, realm, password)).hexdigest()
h2 = md5('%s:%s' % (self.request.method,
self.request.path)).hexdigest()
digest = md5('%s:%s:%s' % (h1, nonce, h2)).hexdigest()
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
h2 = md5(utf8('%s:%s' % (self.request.method,
self.request.path))).hexdigest()
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
if digest == param_dict['response']:
self.write('ok')
else:
@ -66,15 +68,28 @@ class DigestAuthHandler(RequestHandler):
(realm, nonce, opaque))
class CustomReasonHandler(RequestHandler):
def get(self):
self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
def get(self):
self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = CurlAsyncHTTPClient(self.io_loop)
self.http_client = CurlAsyncHTTPClient(self.io_loop,
defaults=dict(allow_ipv6=False))
def get_app(self):
return Application([
('/digest', DigestAuthHandler),
('/custom_reason', CustomReasonHandler),
('/custom_fail_reason', CustomFailReasonHandler),
])
def test_prepare_curl_callback_stack_context(self):
@ -97,3 +112,11 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
response = self.fetch('/digest', auth_mode='digest',
auth_username='foo', auth_password='bar')
self.assertEqual(response.body, b'ok')
def test_custom_reason(self):
response = self.fetch('/custom_reason')
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")

View file

@ -144,7 +144,7 @@ class EscapeTestCase(unittest.TestCase):
(u("<foo>"), u("&lt;foo&gt;")),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"", "&lt;&gt;&amp;&quot;"),
("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
("&amp;", "&amp;amp;"),
(u("<\u00e9>"), u("&lt;\u00e9&gt;")),

View file

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import datetime
import functools
import sys
import textwrap
@ -8,7 +9,7 @@ import time
import platform
import weakref
from tornado.concurrent import return_future
from tornado.concurrent import return_future, Future
from tornado.escape import url_escape
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop
@ -20,6 +21,10 @@ from tornado.web import Application, RequestHandler, asynchronous, HTTPError
from tornado import gen
try:
from concurrent import futures
except ImportError:
futures = None
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
@ -281,18 +286,67 @@ class GenEngineTest(AsyncTestCase):
self.stop()
self.run_gen(f)
def test_multi_delayed(self):
def test_multi_dict(self):
@gen.engine
def f():
(yield gen.Callback("k1"))("v1")
(yield gen.Callback("k2"))("v2")
results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2"))
self.assertEqual(results, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
# The following tests explicitly run with both gen.Multi
# and gen.multi_future (Task returns a Future, so it can be used
# with either).
def test_multi_yieldpoint_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield [
responses = yield gen.Multi([
gen.Task(self.delay_callback, 3, arg="v1"),
gen.Task(self.delay_callback, 1, arg="v2"),
]
])
self.assertEqual(responses, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_multi_yieldpoint_dict_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield gen.Multi(dict(
foo=gen.Task(self.delay_callback, 3, arg="v1"),
bar=gen.Task(self.delay_callback, 1, arg="v2"),
))
self.assertEqual(responses, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
def test_multi_future_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield gen.multi_future([
gen.Task(self.delay_callback, 3, arg="v1"),
gen.Task(self.delay_callback, 1, arg="v2"),
])
self.assertEqual(responses, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_multi_future_dict_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield gen.multi_future(dict(
foo=gen.Task(self.delay_callback, 3, arg="v1"),
bar=gen.Task(self.delay_callback, 1, arg="v2"),
))
self.assertEqual(responses, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
@skipOnTravis
@gen_test
def test_multi_performance(self):
@ -304,6 +358,23 @@ class GenEngineTest(AsyncTestCase):
end = time.time()
self.assertLess(end - start, 1.0)
@gen_test
def test_multi_empty(self):
# Empty lists or dicts should return the same type.
x = yield []
self.assertTrue(isinstance(x, list))
y = yield {}
self.assertTrue(isinstance(y, dict))
@gen_test
def test_multi_mixed_types(self):
# A YieldPoint (Wait) and Future (Task) can be combined
# (and use the YieldPoint codepath)
(yield gen.Callback("k1"))("v1")
responses = yield [gen.Wait("k1"),
gen.Task(self.delay_callback, 3, arg="v2")]
self.assertEqual(responses, ["v1", "v2"])
@gen_test
def test_future(self):
result = yield self.async_future(1)
@ -314,6 +385,11 @@ class GenEngineTest(AsyncTestCase):
results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2])
@gen_test
def test_multi_dict_future(self):
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
self.assertEqual(results, dict(foo=1, bar=2))
def test_arguments(self):
@gen.engine
def f():
@ -698,8 +774,14 @@ class GenCoroutineTest(AsyncTestCase):
def test_replace_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and replaced.
# Note that this test and the following are for behavior that is
# not really supported any more: coroutines no longer create a
# stack context automatically; but one is created after the first
# YieldPoint (i.e. not a Future).
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
@ -718,6 +800,8 @@ class GenCoroutineTest(AsyncTestCase):
# can be caught and ignored.
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
@ -729,6 +813,31 @@ class GenCoroutineTest(AsyncTestCase):
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_moment(self):
calls = []
@gen.coroutine
def f(name, yieldable):
for i in range(5):
calls.append(name)
yield yieldable
# First, confirm the behavior without moment: each coroutine
# monopolizes the event loop until it finishes.
immediate = Future()
immediate.set_result(None)
yield [f('a', immediate), f('b', immediate)]
self.assertEqual(''.join(calls), 'aaaaabbbbb')
# With moment, they take turns.
calls = []
yield [f('a', gen.moment), f('b', gen.moment)]
self.assertEqual(''.join(calls), 'ababababab')
self.finished = True
calls = []
yield [f('a', gen.moment), f('b', immediate)]
self.assertEqual(''.join(calls), 'abbbbbaaaa')
class GenSequenceHandler(RequestHandler):
@asynchronous
@ -803,7 +912,6 @@ class GenExceptionHandler(RequestHandler):
class GenCoroutineExceptionHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
# This test depends on the order of the two decorators.
@ -909,5 +1017,55 @@ class GenWebTest(AsyncHTTPTestCase):
response = self.fetch('/async_prepare_error')
self.assertEqual(response.code, 403)
class WithTimeoutTest(AsyncTestCase):
@gen_test
def test_timeout(self):
with self.assertRaises(gen.TimeoutError):
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
Future())
@gen_test
def test_completes_before_timeout(self):
future = Future()
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
self.assertEqual(result, 'asdf')
@gen_test
def test_fails_before_timeout(self):
future = Future()
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
lambda: future.set_exception(ZeroDivisionError))
with self.assertRaises(ZeroDivisionError):
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_timeout_concurrent_future(self):
with futures.ThreadPoolExecutor(1) as executor:
with self.assertRaises(gen.TimeoutError):
yield gen.with_timeout(self.io_loop.time(),
executor.submit(time.sleep, 0.1))
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_completed_concurrent_future(self):
with futures.ThreadPoolExecutor(1) as executor:
yield gen.with_timeout(datetime.timedelta(seconds=3600),
executor.submit(lambda: None))
if __name__ == '__main__':
unittest.main()

View file

@ -18,7 +18,7 @@ from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, url
@ -110,6 +110,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/all_methods", AllMethodsHandler),
], gzip=True)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
@ -309,7 +310,7 @@ Transfer-Encoding: chunked
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_configure_defaults(self):
defaults = dict(user_agent='TestDefaultUserAgent')
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults)
@ -355,11 +356,10 @@ Transfer-Encoding: chunked
@gen_test
def test_future_http_error(self):
try:
with self.assertRaises(HTTPError) as context:
yield self.http_client.fetch(self.get_url('/notfound'))
except HTTPError as e:
self.assertEqual(e.code, 404)
self.assertEqual(e.response.code, 404)
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
@ -387,6 +387,19 @@ Transfer-Encoding: chunked
allow_nonstandard_methods=True)
self.assertEqual(response.body, b'OTHER')
@gen_test
def test_body(self):
hello_url = self.get_url('/hello')
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, body='data')
self.assertTrue('must be empty' in str(context.exception))
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, method='POST')
self.assertTrue('must not be empty' in str(context.exception))
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
@ -433,17 +446,22 @@ class HTTPResponseTestCase(unittest.TestCase):
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
if IOLoop.configured_class().__name__ in ('TwistedIOLoop',
'AsyncIOMainLoop'):
# TwistedIOLoop only supports the global reactor, so we can't have
# separate IOLoops for client and server threads.
# AsyncIOMainLoop doesn't work with the default policy
# (although it could with some tweaks to this test and a
# policy that created loops for non-main threads).
raise unittest.SkipTest(
'Sync HTTPClient not compatible with TwistedIOLoop')
'Sync HTTPClient not compatible with TwistedIOLoop or '
'AsyncIOMainLoop')
self.server_ioloop = IOLoop()
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
server = HTTPServer(app, io_loop=self.server_ioloop)
server.add_socket(sock)
self.server = HTTPServer(app, io_loop=self.server_ioloop)
self.server.add_socket(sock)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
self.server_thread.start()
@ -451,7 +469,10 @@ class SyncHTTPClientTest(unittest.TestCase):
self.http_client = HTTPClient()
def tearDown(self):
self.server_ioloop.add_callback(self.server_ioloop.stop)
def stop_server():
self.server.stop()
self.server_ioloop.stop()
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
@ -469,3 +490,28 @@ class SyncHTTPClientTest(unittest.TestCase):
with self.assertRaises(HTTPError) as assertion:
self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(assertion.exception.code, 404)
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
self.assertEqual(request.headers, {'foo': 'bar'})
def test_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = {'bar': 'baz'}
self.assertEqual(request.headers, {'bar': 'baz'})
def test_null_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = None
self.assertEqual(request.headers, {})
def test_body(self):
request = HTTPRequest('http://example.com', body='foo')
self.assertEqual(request.body, utf8('foo'))
def test_body_setter(self):
request = HTTPRequest('http://example.com')
request.body = 'foo'
self.assertEqual(request.body, utf8('foo'))

View file

@ -2,20 +2,23 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
from tornado import netutil
from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
from tornado import gen
from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context, Resolver
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import unittest
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, asynchronous
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing
import datetime
import gzip
import os
import shutil
import socket
@ -23,6 +26,28 @@ import ssl
import sys
import tempfile
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
chunks = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
def data_received(self, chunk):
chunks.append(chunk)
def finish(self):
callback((self.headers, b''.join(chunks)))
conn = HTTP1Connection(stream, True)
conn.read_response(Delegate())
class HandlerBaseTestCase(AsyncHTTPTestCase):
def get_app(self):
@ -86,11 +111,13 @@ class SSLTestMixin(object):
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
with ExpectLog(gen_log, 'Uncaught exception', required=False):
self.http_client.fetch(
self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
self.assertEqual(response.code, 599)
# Python's SSL implementation differs significantly between versions.
@ -163,18 +190,7 @@ class MultipartTestHandler(RequestHandler):
})
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
def set_request(self, request):
self.__next_request = request
def _on_connect(self):
self.stream.write(self.__next_request)
self.__next_request = None
self.stream.read_until(b"\r\n\r\n", self._on_headers)
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
return [("/multipart", MultipartTestHandler),
@ -184,23 +200,16 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
with closing(Resolver(io_loop=self.io_loop)) as resolver:
with closing(SimpleAsyncHTTPClient(self.io_loop,
resolver=resolver)) as client:
conn = RawRequestHTTPConnection(
self.io_loop, client,
httpclient._RequestProxy(
httpclient.HTTPRequest(self.get_url("/")),
dict(httpclient.HTTPRequest._DEFAULTS)),
None, self.stop,
1024 * 1024, resolver)
conn.set_request(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
response = self.wait()
response.rethrow()
return response
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
stream.write(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
return body
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
@ -211,17 +220,17 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
b"X-Header-encoding-test: \xe9",
],
b"\r\n".join([
b"Content-Disposition: form-data; name=argument",
b"",
u("\u00e1").encode("utf-8"),
b"--1234567890",
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
b"",
u("\u00fa").encode("utf-8"),
b"--1234567890--",
b"",
b"Content-Disposition: form-data; name=argument",
b"",
u("\u00e1").encode("utf-8"),
b"--1234567890",
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
b"",
u("\u00fa").encode("utf-8"),
b"--1234567890--",
b"",
]))
data = json_decode(response.body)
data = json_decode(response)
self.assertEqual(u("\u00e9"), data["header"])
self.assertEqual(u("\u00e1"), data["argument"])
self.assertEqual(u("\u00f3"), data["filename"])
@ -344,6 +353,21 @@ class HTTPServerTest(AsyncHTTPTestCase):
self.assertEqual(200, response.code)
self.assertEqual(json_decode(response.body), {})
def test_malformed_body(self):
# parse_qs is pretty forgiving, but it will fail on python 3
# if the data is not utf8. On python 2 parse_qs will work,
# but then the recursive_unicode call in EchoHandler will
# fail.
if str is bytes_type:
return
with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
response = self.fetch(
'/echo', method="POST",
headers={'Content-Type': 'application/x-www-form-urlencoded'},
body=b'\xe9')
self.assertEqual(200, response.code)
self.assertEqual(b'{}', response.body)
class HTTPServerRawTest(AsyncHTTPTestCase):
def get_app(self):
@ -382,6 +406,25 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
self.stop)
self.wait()
def test_chunked_request_body(self):
# Chunked requests are not widely supported and we don't have a way
# to generate them in AsyncHTTPClient, but HTTPServer will read them.
self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded
4
foo=
3
bar
0
""".replace(b"\n", b"\r\n"))
read_stream_body(self.stream, self.stop)
headers, response = self.wait()
self.assertEqual(json_decode(response), {u('foo'): [u('bar')]})
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
@ -497,31 +540,40 @@ class UnixSocketTest(AsyncTestCase):
def setUp(self):
super(UnixSocketTest, self).setUp()
self.tmpdir = tempfile.mkdtemp()
self.sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(self.sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
self.server = HTTPServer(app, io_loop=self.io_loop)
self.server.add_socket(sock)
self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
self.stream.connect(self.sockfile, self.stop)
self.wait()
def tearDown(self):
self.stream.close()
self.server.stop()
shutil.rmtree(self.tmpdir)
super(UnixSocketTest, self).tearDown()
def test_unix_socket(self):
sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
server = HTTPServer(app, io_loop=self.io_loop)
server.add_socket(sock)
stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
stream.connect(sockfile, self.stop)
self.wait()
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
stream.read_until(b"\r\n", self.stop)
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
self.stream.read_until(b"\r\n", self.stop)
response = self.wait()
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
stream.read_until(b"\r\n\r\n", self.stop)
self.stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
stream.read_bytes(int(headers["Content-Length"]), self.stop)
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
body = self.wait()
self.assertEqual(body, b"Hello world")
stream.close()
server.stop()
def test_unix_socket_bad_request(self):
# Unix sockets don't have remote addresses so they just return an
# empty string.
with ExpectLog(gen_log, "Malformed HTTP message from"):
self.stream.write(b"garbage\r\n\r\n")
self.stream.read_until_close(self.stop)
response = self.wait()
self.assertEqual(response, b"")
class KeepAliveTest(AsyncHTTPTestCase):
@ -586,8 +638,8 @@ class KeepAliveTest(AsyncHTTPTestCase):
return headers
def read_response(self):
headers = self.read_headers()
self.stream.read_bytes(int(headers['Content-Length']), self.stop)
self.headers = self.read_headers()
self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
body = self.wait()
self.assertEqual(b'Hello world', body)
@ -621,6 +673,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.stream.read_until_close(callback=self.stop)
data = self.wait()
self.assertTrue(not data)
self.assertTrue('Connection' not in self.headers)
self.close()
def test_http10_keepalive(self):
@ -628,8 +681,10 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_pipelined_requests(self):
@ -659,3 +714,322 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
self.read_headers()
self.close()
class GzipBaseTest(object):
def get_app(self):
return Application([('/', EchoHandler)])
def post_gzip(self, body):
bytesio = BytesIO()
gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
gzip_file.write(utf8(body))
gzip_file.close()
compressed_body = bytesio.getvalue()
return self.fetch('/', method='POST', body=compressed_body,
headers={'Content-Encoding': 'gzip'})
def test_uncompressed(self):
response = self.fetch('/', method='POST', body='foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
def get_httpserver_options(self):
return dict(decompress_request=True)
def test_gzip(self):
response = self.post_gzip('foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
def test_gzip_unsupported(self):
# Gzip support is opt-in; without it the server fails to parse
# the body (but parsing form bodies is currently just a log message,
# not a fatal error).
with ExpectLog(gen_log, "Unsupported Content-Encoding"):
response = self.post_gzip('foo=bar')
self.assertEquals(json_decode(response.body), {})
class StreamingChunkSizeTest(AsyncHTTPTestCase):
# 50 characters long, and repetitive so it can be compressed.
BODY = b'01234567890123456789012345678901234567890123456789'
CHUNK_SIZE = 16
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def get_httpserver_options(self):
return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def headers_received(self, start_line, headers):
self.chunk_lengths = []
def data_received(self, chunk):
self.chunk_lengths.append(len(chunk))
def finish(self):
response_body = utf8(json_encode(self.chunk_lengths))
self.connection.write_headers(
ResponseStartLine('HTTP/1.1', 200, 'OK'),
HTTPHeaders({'Content-Length': str(len(response_body))}))
self.connection.write(response_body)
self.connection.finish()
def get_app(self):
class App(HTTPServerConnectionDelegate):
def start_request(self, connection):
return StreamingChunkSizeTest.MessageDelegate(connection)
return App()
def fetch_chunk_sizes(self, **kwargs):
response = self.fetch('/', method='POST', **kwargs)
response.rethrow()
chunks = json_decode(response.body)
self.assertEqual(len(self.BODY), sum(chunks))
for chunk_size in chunks:
self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
'oversized chunk: ' + str(chunks))
self.assertGreater(chunk_size, 0,
'empty chunk: ' + str(chunks))
return chunks
def compress(self, body):
bytesio = BytesIO()
gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
gzfile.write(body)
gzfile.close()
compressed = bytesio.getvalue()
if len(compressed) >= len(body):
raise Exception("body did not shrink when compressed")
return compressed
def test_regular_body(self):
chunks = self.fetch_chunk_sizes(body=self.BODY)
# Without compression we know exactly what to expect.
self.assertEqual([16, 16, 16, 2], chunks)
def test_compressed_body(self):
self.fetch_chunk_sizes(body=self.compress(self.BODY),
headers={'Content-Encoding': 'gzip'})
# Compression creates irregular boundaries so the assertions
# in fetch_chunk_sizes are as specific as we can get.
def test_chunked_body(self):
def body_producer(write):
write(self.BODY[:20])
write(self.BODY[20:])
chunks = self.fetch_chunk_sizes(body_producer=body_producer)
# HTTP chunk boundaries translate to application-visible breaks
self.assertEqual([16, 4, 16, 14], chunks)
def test_chunked_compressed(self):
compressed = self.compress(self.BODY)
self.assertGreater(len(compressed), 20)
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
self.fetch_chunk_sizes(body_producer=body_producer,
headers={'Content-Encoding': 'gzip'})
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(max_header_size=1024)
def test_small_headers(self):
response = self.fetch("/", headers={'X-Filler': 'a' * 100})
response.rethrow()
self.assertEqual(response.body, b"Hello world")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
self.assertEqual(response.code, 599)
@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(idle_connection_timeout=0.1)
def setUp(self):
super(IdleTimeoutTest, self).setUp()
self.streams = []
def tearDown(self):
super(IdleTimeoutTest, self).tearDown()
for stream in self.streams:
stream.close()
def connect(self):
stream = IOStream(socket.socket())
stream.connect(('localhost', self.get_http_port()), self.stop)
self.wait()
self.streams.append(stream)
return stream
def test_unused_connection(self):
stream = self.connect()
stream.set_close_callback(self.stop)
self.wait()
def test_idle_after_use(self):
stream = self.connect()
stream.set_close_callback(lambda: self.stop("closed"))
# Use the connection twice to make sure keep-alives are working
for i in range(2):
stream.write(b"GET / HTTP/1.1\r\n\r\n")
stream.read_until(b"\r\n\r\n", self.stop)
self.wait()
stream.read_bytes(11, self.stop)
data = self.wait()
self.assertEqual(data, b"Hello world")
# Now let the timeout trigger and close the connection.
data = self.wait()
self.assertEqual(data, "closed")
class BodyLimitsTest(AsyncHTTPTestCase):
def get_app(self):
class BufferedHandler(RequestHandler):
def put(self):
self.write(str(len(self.request.body)))
@stream_request_body
class StreamingHandler(RequestHandler):
def initialize(self):
self.bytes_read = 0
def prepare(self):
if 'expected_size' in self.request.arguments:
self.request.connection.set_max_body_size(
int(self.get_argument('expected_size')))
if 'body_timeout' in self.request.arguments:
self.request.connection.set_body_timeout(
float(self.get_argument('body_timeout')))
def data_received(self, data):
self.bytes_read += len(data)
def put(self):
self.write(str(self.bytes_read))
return Application([('/buffered', BufferedHandler),
('/streaming', StreamingHandler)])
def get_httpserver_options(self):
return dict(body_timeout=3600, max_body_size=4096)
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def test_small_body(self):
response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
self.assertEqual(response.body, b'4096')
response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
self.assertEqual(response.body, b'4096')
def test_large_body_buffered(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
def test_large_body_buffered_chunked(self):
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/buffered', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
def test_large_body_streaming(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
def test_large_body_streaming_chunked(self):
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/streaming', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
def test_large_body_streaming_override(self):
response = self.fetch('/streaming?expected_size=10240', method='PUT',
body=b'a' * 10240)
self.assertEqual(response.body, b'10240')
def test_large_body_streaming_chunked_override(self):
response = self.fetch('/streaming?expected_size=10240', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.body, b'10240')
@gen_test
def test_timeout(self):
stream = IOStream(socket.socket())
try:
yield stream.connect(('127.0.0.1', self.get_http_port()))
# Use a raw stream because AsyncHTTPClient won't let us read a
# response without finishing a body.
stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
b'Content-Length: 42\r\n\r\n')
with ExpectLog(gen_log, 'Timeout reading body'):
response = yield stream.read_until_close()
self.assertEqual(response, b'')
finally:
stream.close()
@gen_test
def test_body_size_override_reset(self):
# The max_body_size override is reset between requests.
stream = IOStream(socket.socket())
try:
yield stream.connect(('127.0.0.1', self.get_http_port()))
# Use a raw stream so we can make sure it's all on one connection.
stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
stream.write(b'a' * 10240)
headers, response = yield gen.Task(read_stream_body, stream)
self.assertEqual(response, b'10240')
# Without the ?expected_size parameter, we get the old default value
stream.write(b'PUT /streaming HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
with ExpectLog(gen_log, '.*Content-Length too long'):
data = yield stream.read_until_close()
self.assertEqual(data, b'')
finally:
stream.close()
class LegacyInterfaceTest(AsyncHTTPTestCase):
def get_app(self):
# The old request_callback interface does not implement the
# delegate interface, and writes its response via request.write
# instead of request.connection.write_headers.
def handle_request(request):
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
request.write(message)
request.finish()
return handle_request
def test_legacy_interface(self):
response = self.fetch('/')
self.assertEqual(response.body, b"Hello world")

View file

@ -13,6 +13,7 @@ class ImportTest(unittest.TestCase):
# import tornado.curl_httpclient # depends on pycurl
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil

View file

@ -12,8 +12,9 @@ import time
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
try:
@ -51,7 +52,8 @@ class TestIOLoop(AsyncTestCase):
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
self.assertAlmostEqual(time.time(), self.stop_time, places=2)
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
@ -153,7 +155,7 @@ class TestIOLoop(AsyncTestCase):
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop())
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
@ -171,6 +173,131 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
with NullContext():
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1/0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1/0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
@ -329,5 +456,6 @@ class TestIOLoopRunSync(unittest.TestCase):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
if __name__ == "__main__":
unittest.main()

View file

@ -1,13 +1,16 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import certifi
import errno
import logging
import os
@ -17,6 +20,13 @@ import ssl
import sys
def _server_ssl_options():
return dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello")
@ -106,6 +116,48 @@ class TestIOStreamWebMixin(object):
stream.close()
@gen_test
def test_future_interface(self):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
("localhost", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
content_length = int(headers['Content-Length'])
body = yield stream.read_bytes(content_length)
self.assertEqual(body, b'Hello')
stream.close()
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
stream.close()
@gen_test
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
self.assertEqual(body, b"Hello")
# Nothing else to read; the error comes immediately without waiting
# for yield.
with self.assertRaises(StreamClosedError):
stream.read_bytes(1)
class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
@ -120,16 +172,6 @@ class TestIOStreamMixin(object):
def accept_callback(connection, address):
streams[0] = self._make_server_iostream(connection, **kwargs)
if isinstance(streams[0], SSLIOStream):
# HACK: The SSL handshake won't complete (and
# therefore the client connect callback won't be
# run)until the server side has tried to do something
# with the connection. For these tests we want both
# sides to connect before we do anything else with the
# connection, so we must cause some dummy activity on the
# server. If this turns out to be useful for real apps
# it should have a cleaner interface.
streams[0]._add_io_state(IOLoop.READ)
self.stop()
def connect_callback():
@ -168,9 +210,6 @@ class TestIOStreamMixin(object):
server, client = self.make_iostream_pair()
server.write(b'', callback=self.stop)
self.wait()
# As a side effect, the stream is now listening for connection
# close (if it wasn't already), but is not listening for writes
self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
server.close()
client.close()
@ -193,8 +232,11 @@ class TestIOStreamMixin(object):
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
if sys.platform != 'cygwin':
_ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
if hasattr(errno, "WSAECONNREFUSED"):
_ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
# cygwin's errnos don't match those used on native windows python
self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)
self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
def test_gaierror(self):
# Test that IOStream sets its exc_info on getaddrinfo error
@ -308,6 +350,25 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
def f(self):
server.write(b"12")
chunks = []
chunks.append((yield client.read_bytes(1)))
server.close()
chunks.append((yield client.read_bytes(1)))
self.assertEqual(chunks, [b"1", b"2"])
try:
f(self)
finally:
server.close()
client.close()
def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
@ -340,14 +401,18 @@ class TestIOStreamMixin(object):
# Similar to test_delayed_close_callback, but read_until_close takes
# a separate code path so test it separately.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
# Read one byte to make sure the client has received the data.
# It won't run the close callback as long as there is more buffered
# data that could satisfy a later read.
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
client.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"1234")
self.assertEqual(data, b"234")
finally:
server.close()
client.close()
@ -357,17 +422,18 @@ class TestIOStreamMixin(object):
# All data should go through the streaming callback,
# and the final read callback just gets an empty string.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
streaming_data = []
client.read_until_close(self.stop,
streaming_callback=streaming_data.append)
data = self.wait()
self.assertEqual(b'', data)
self.assertEqual(b''.join(streaming_data), b"1234")
self.assertEqual(b''.join(streaming_data), b"234")
finally:
server.close()
client.close()
@ -461,6 +527,203 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_future_close_callback(self):
# Regression test for interaction between the Future read interfaces
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
def close_callback():
closed[0] = True
self.stop()
server.set_close_callback(close_callback)
try:
client.write(b'a')
future = server.read_bytes(1)
self.io_loop.add_future(future, self.stop)
self.assertEqual(self.wait().result(), b'a')
self.assertFalse(closed[0])
client.close()
self.wait()
self.assertTrue(closed[0])
finally:
server.close()
client.close()
def test_read_bytes_partial(self):
server, client = self.make_iostream_pair()
try:
# Ask for more than is available with partial=True
client.read_bytes(50, self.stop, partial=True)
server.write(b"hello")
data = self.wait()
self.assertEqual(data, b"hello")
# Ask for less than what is available; num_bytes is still
# respected.
client.read_bytes(3, self.stop, partial=True)
server.write(b"world")
data = self.wait()
self.assertEqual(data, b"wor")
# Partial reads won't return an empty string, but read_bytes(0)
# will.
client.read_bytes(0, self.stop, partial=True)
data = self.wait()
self.assertEqual(data, b'')
finally:
server.close()
client.close()
def test_read_until_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until_regex(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until_regex(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_small_reads_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write(b"a" * 1024 * 100)
for i in range(100):
client.read_bytes(1024, self.stop)
data = self.wait()
self.assertEqual(data, b"a" * 1024)
finally:
server.close()
client.close()
def test_small_read_untils_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write((b"a" * 1023 + b"\n") * 100)
for i in range(100):
client.read_until(b"\n", self.stop, max_bytes=4096)
data = self.wait()
self.assertEqual(data, b"a" * 1023 + b"\n")
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
@ -482,14 +745,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase):
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
ssl_options = dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
connection = ssl.wrap_socket(connection,
server_side=True,
do_handshake_on_connect=False,
**ssl_options)
**_server_ssl_options())
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
@ -517,6 +776,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
ssl_options=context, **kwargs)
class TestIOStreamStartTLS(AsyncTestCase):
def setUp(self):
try:
super(TestIOStreamStartTLS, self).setUp()
self.listener, self.port = bind_unused_port()
self.server_stream = None
self.server_accepted = Future()
netutil.add_accept_handler(self.listener, self.accept)
self.client_stream = IOStream(socket.socket())
self.io_loop.add_future(self.client_stream.connect(
('127.0.0.1', self.port)), self.stop)
self.wait()
self.io_loop.add_future(self.server_accepted, self.stop)
self.wait()
except Exception as e:
print(e)
raise
def tearDown(self):
if self.server_stream is not None:
self.server_stream.close()
if self.client_stream is not None:
self.client_stream.close()
self.listener.close()
super(TestIOStreamStartTLS, self).tearDown()
def accept(self, connection, address):
if self.server_stream is not None:
self.fail("should only get one connection")
self.server_stream = IOStream(connection)
self.server_accepted.set_result(None)
@gen.coroutine
def client_send_line(self, line):
self.client_stream.write(line)
recv_line = yield self.server_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
@gen.coroutine
def server_send_line(self, line):
self.server_stream.write(line)
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
def client_start_tls(self, ssl_options=None):
client_stream = self.client_stream
self.client_stream = None
return client_stream.start_tls(False, ssl_options)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
self.server_stream = None
return server_stream.start_tls(True, ssl_options)
@gen_test
def test_start_tls_smtp(self):
# This flow is simplified from RFC 3207 section 5.
# We don't really need all of this, but it helps to make sure
# that after realistic back-and-forth traffic the buffers end up
# in a sane state.
yield self.server_send_line(b"220 mail.example.com ready\r\n")
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250-mail.example.com welcome\r\n")
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
client_future = self.client_start_tls()
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
self.assertTrue(isinstance(self.client_stream, SSLIOStream))
self.assertTrue(isinstance(self.server_stream, SSLIOStream))
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250 mail.example.com welcome\r\n")
@gen_test
def test_handshake_fail(self):
self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):
@ -543,3 +887,21 @@ class TestPipeIOStream(AsyncTestCase):
self.assertEqual(data, b"ld")
rs.close()
def test_pipe_iostream_big_write(self):
r, w = os.pipe()
rs = PipeIOStream(r, io_loop=self.io_loop)
ws = PipeIOStream(w, io_loop=self.io_loop)
NUM_BYTES = 1048576
# Write 1MB of data, which should fill the buffer
ws.write(b"1" * NUM_BYTES)
rs.read_bytes(NUM_BYTES, self.stop)
data = self.wait()
self.assertEqual(data, b"1" * NUM_BYTES)
ws.close()
rs.close()

View file

@ -20,6 +20,8 @@ import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import warnings
@ -52,7 +54,6 @@ class LogFormatterTest(unittest.TestCase):
logging.ERROR: u("\u0001"),
}
self.formatter._normal = u("\u0002")
self.formatter._color = True
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger('LogFormatterTest')
self.logger.propagate = False
@ -157,3 +158,50 @@ class EnablePrettyLoggingTest(unittest.TestCase):
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then logs
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
IMPORT = 'from tornado.options import options, parse_command_line'
LOG_INFO = 'import logging; logging.info("hello")'
program = ';'.join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
[sys.executable, '-c', program] + (args or []),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
return b'hello' in stdout
def test_default(self):
self.assertFalse(self.logs_present('pass'))
def test_tornado_default(self):
self.assertTrue(self.logs_present('parse_command_line()'))
def test_disable_command_line(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=none']))
def test_disable_command_line_case_insensitive(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=None']))
def test_disable_code_string(self):
self.assertFalse(self.logs_present(
'options.logging = "none"; parse_command_line()'))
def test_disable_code_none(self):
self.assertFalse(self.logs_present(
'options.logging = None; parse_command_line()'))
def test_disable_override(self):
# command line trumps code defaults
self.assertTrue(self.logs_present(
'options.logging = None; parse_command_line()',
['--logging=info']))

View file

@ -1,10 +1,16 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import signal
import socket
from subprocess import Popen
import sys
import time
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
from tornado.test.util import unittest, skipIfNoNetwork
try:
from concurrent import futures
@ -20,6 +26,7 @@ else:
try:
import twisted
import twisted.names
except ImportError:
twisted = None
else:
@ -27,6 +34,15 @@ else:
class _ResolverTestMixin(object):
def skipOnCares(self):
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
if self.resolver.__class__.__name__ == 'CaresResolver':
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
@ -39,13 +55,34 @@ class _ResolverTestMixin(object):
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
def test_bad_host(self):
self.skipOnCares()
def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val)
return True # Halt propagation.
with ExceptionStackContext(handler):
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
result = self.wait()
self.assertIsInstance(result, Exception)
@gen_test
def test_future_interface_bad_host(self):
self.skipOnCares()
with self.assertRaises(Exception):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
@ -57,6 +94,34 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
super(ThreadedResolverTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
# Test for a deadlock when importing a module that runs the
# ThreadedResolver at import-time. See resolve_test.py for
# full explanation.
command = [
sys.executable,
'-c',
'import tornado.test.resolve_test_helper']
start = time.time()
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
while time.time() - start < TIMEOUT:
return_code = popen.poll()
if return_code is not None:
self.assertEqual(0, return_code)
return # Success.
time.sleep(0.05)
self.fail("import timed out")
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
@ -64,6 +129,7 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = CaresResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -82,3 +148,21 @@ class IsValidIPTest(unittest.TestCase):
self.assertTrue(not is_valid_ip('localhost'))
self.assertTrue(not is_valid_ip('4.4.4.4<'))
self.assertTrue(not is_valid_ip(' 127.0.0.1'))
self.assertTrue(not is_valid_ip(''))
self.assertTrue(not is_valid_ip(' '))
self.assertTrue(not is_valid_ip('\n'))
self.assertTrue(not is_valid_ip('\x00'))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
if 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
sockets = bind_sockets(None, 'localhost')
try:
port = sockets[0].getsockname()[1]
self.assertTrue(all(s.getsockname()[1] == port
for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()

View file

@ -19,8 +19,10 @@ from tornado.web import RequestHandler, Application
def skip_if_twisted():
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
if IOLoop.configured_class().__name__.endswith(('TwistedIOLoop',
'AsyncIOMainLoop')):
raise unittest.SkipTest("Process tests not compatible with "
"TwistedIOLoop or AsyncIOMainLoop")
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@ -135,6 +137,14 @@ class ProcessTest(unittest.TestCase):
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
def test_subprocess(self):
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
# This test fails non-deterministically with LayeredTwistedIOLoop.
# (the read_until('\n') returns '\n' instead of 'hello\n')
# This probably indicates a problem with either TornadoReactor
# or TwistedIOLoop, but I haven't been able to track it down
# and for now this is just causing spurious travis-ci failures.
raise unittest.SkipTest("Subprocess tests not compatible with "
"LayeredTwistedIOLoop")
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,

View file

@ -0,0 +1,12 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver
from tornado.util import u
# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
# but blocks on the import lock. Verify that ThreadedResolver avoids
# this deadlock.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u('localhost'), 80))

View file

@ -13,6 +13,11 @@ from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
from tornado.test.util import unittest
try:
reduce # py2
except NameError:
from functools import reduce # py3
TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
@ -35,6 +40,7 @@ TEST_MODULES = [
'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
@ -60,7 +66,8 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
self.stream.write("\n")
return result
if __name__ == '__main__':
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
@ -77,6 +84,9 @@ if __name__ == '__main__':
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning,
module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings("error", category=PendingDeprecationWarning,
module=r"tornado\..*")
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
@ -86,7 +96,8 @@ if __name__ == '__main__':
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=AsyncHTTPClient.configure)
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
@ -121,3 +132,6 @@ if __name__ == '__main__':
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
tornado.testing.main(**kwargs)
if __name__ == '__main__':
main()

View file

@ -10,16 +10,18 @@ import re
import socket
import sys
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.web import RequestHandler, Application, asynchronous, url
from tornado.test.util import skipOnTravis, skipIfNoIPv6
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
@ -69,7 +71,8 @@ class OptionsHandler(RequestHandler):
class NoContentHandler(RequestHandler):
def get(self):
if self.get_argument("error", None):
self.set_header("Content-Length", "7")
self.set_header("Content-Length", "5")
self.write("hello")
self.set_status(204)
@ -93,6 +96,30 @@ class HostEchoHandler(RequestHandler):
self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
@gen.coroutine
def get(self):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.stream
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
def prepare(self):
self.set_status(403)
self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
@ -111,6 +138,9 @@ class SimpleHTTPClientTestMixin(object):
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
], gzip=True)
def test_singleton(self):
@ -122,9 +152,9 @@ class SimpleHTTPClientTestMixin(object):
SimpleAsyncHTTPClient(self.io_loop,
force_instance=True))
# different IOLoops use different objects
io_loop2 = IOLoop()
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(io_loop2))
with closing(IOLoop()) as io_loop2:
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(io_loop2))
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
@ -162,7 +192,7 @@ class SimpleHTTPClientTestMixin(object):
response.rethrow()
def test_default_certificates_exist(self):
open(_DEFAULT_CA_CERTS).close()
open(_default_ca_certs()).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
@ -212,28 +242,30 @@ class SimpleHTTPClientTestMixin(object):
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
@skipIfNoIPv6
def test_ipv6(self):
try:
self.http_server.listen(self.get_http_port(), address='::1')
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
except socket.gaierror as e:
if e.args[0] == socket.EAI_ADDRFAMILY:
# python supports ipv6, but it's not configured on the network
# interface, so skip this test.
return
raise
url = self.get_url("/hello").replace("localhost", "[::1]")
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
# ipv6 is currently disabled by default and must be explicitly requested
self.http_client.fetch(url, self.stop)
# ipv6 is currently enabled by default but can be disabled
self.http_client.fetch(url, self.stop, allow_ipv6=False)
response = self.wait()
self.assertEqual(response.code, 599)
self.http_client.fetch(url, self.stop, allow_ipv6=True)
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(response.body, b"Hello world!")
def test_multiple_content_length_accepted(self):
def xtest_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
@ -265,7 +297,8 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
response = self.fetch("/no_content?error=1")
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
def test_host_header(self):
@ -288,14 +321,86 @@ class SimpleHTTPClientTestMixin(object):
if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here
self.assertTrue(str(errno.ECONNREFUSED) in str(response.error),
response.error)
contains_errno = str(errno.ECONNREFUSED) in str(response.error)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
self.assertTrue(contains_errno, response.error)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(response.error),
response.error)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/trigger'), self.stop,
request_timeout=10)
# Wait for the trigger request to block, not complete.
self.wait()
client.fetch(self.get_url('/hello'), self.stop,
connect_timeout=0.1)
response = self.wait()
self.assertEqual(response.code, 599)
self.assertTrue(response.request_time < 1, response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
self.triggers.popleft()()
self.wait()
def test_no_content_length(self):
response = self.fetch("/no_content_length")
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
write(b'5678')
@gen.coroutine
def async_body_producer(self, write):
yield write(b'1234')
yield gen.Task(IOLoop.current().add_callback)
yield write(b'5678')
def test_sync_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch("/echo_post", method="POST",
body=b"1234",
expect_100_continue=True)
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
response = self.fetch("/respond_in_prepare", method="POST",
body_producer=body_producer,
expect_100_continue=True)
self.assertEqual(response.code, 403)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
@ -396,3 +501,52 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def setUp(self):
# Dummy Resolver subclass that never invokes its callback.
class BadResolver(Resolver):
def resolve(self, *args, **kwargs):
pass
super(ResolveTimeoutTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
resolver=BadResolver())
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_resolve_timeout(self):
response = self.fetch('/hello', connect_timeout=0.1)
self.assertEqual(response.code, 599)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 100)
self.write("ok")
class LargeHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
return Application([('/small', SmallHeaders),
('/large', LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
def test_small_headers(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'ok')
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)

View file

@ -35,11 +35,11 @@ class TestRequestHandler(RequestHandler):
logging.debug('in part3()')
raise Exception('test exception')
def get_error_html(self, status_code, **kwargs):
if 'exception' in kwargs and str(kwargs['exception']) == 'test exception':
return 'got expected exception'
def write_error(self, status_code, **kwargs):
if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
self.write('got expected exception')
else:
return 'unexpected failure'
self.write('unexpected failure')
class HTTPStackContextTest(AsyncHTTPTestCase):
@ -219,16 +219,22 @@ class StackContextTest(AsyncTestCase):
def test_yield_in_with(self):
@gen.engine
def f():
self.callback = yield gen.Callback('a')
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Task(self.io_loop.add_callback)
yield gen.Wait('a')
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
# Cleanup: to avoid GC warnings (which for some reason only seem
# to show up on py33-asyncio), invoke the callback (which will do
# nothing since the gen.Runner is already finished) and delete it.
self.callback()
del self.callback
@gen_test
def test_yield_outside_with(self):
@ -256,12 +262,13 @@ class StackContextTest(AsyncTestCase):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
@gen_test
def test_run_with_stack_context(self):
@gen.coroutine
def f1():
self.assertEqual(self.active_contexts, ['c1'])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
StackContext(functools.partial(self.context, 'c2')),
f2)
self.assertEqual(self.active_contexts, ['c1'])
@ -272,7 +279,7 @@ class StackContextTest(AsyncTestCase):
self.assertEqual(self.active_contexts, ['c1', 'c2'])
self.assertEqual(self.active_contexts, [])
run_with_stack_context(
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
f1)
self.assertEqual(self.active_contexts, [])

View file

@ -0,0 +1,278 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from contextlib import closing
import os
import socket
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
from tornado.test.util import skipIfNoIPv6, unittest
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
sockets = bind_sockets(None, 'localhost', family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
def stop(self):
super(TestTCPServer, self).stop()
for stream in self.streams:
stream.close()
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
self.server = None
self.client = TCPClient()
def start_server(self, family):
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def stop_server(self):
if self.server is not None:
self.server.stop()
self.server = None
def tearDown(self):
self.client.close()
self.stop_server()
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop)
addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host):
port = self.start_server(family)
stream = yield self.client.connect(host, port)
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
self.do_test_connect(socket.AF_INET, '127.0.0.1')
def test_connect_ipv4_dual(self):
self.do_test_connect(socket.AF_INET, 'localhost')
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_INET6, '::1')
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
self.skipTest('TwistedResolver does not support multiple addresses')
self.do_test_connect(socket.AF_INET6, 'localhost')
def test_connect_unspec_ipv4(self):
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_UNSPEC, '::1')
def test_connect_unspec_dual(self):
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(primary, [(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF2, 'b'),
(AF1, 'c'),
(AF2, 'd')])
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
class ConnectorTest(AsyncTestCase):
class FakeStream(object):
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def setUp(self):
super(ConnectorTest, self).setUp()
self.connect_futures = {}
self.streams = {}
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
(AF2, 'c'), (AF2, 'd')]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
# be closing any streams
for stream in self.streams.values():
self.assertFalse(stream.closed)
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
future = Future()
self.connect_futures[(af, addr)] = future
return future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
self.streams[addr] = ConnectorTest.FakeStream()
future.set_result(self.streams[addr])
else:
future.set_exception(IOError())
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600)
return conn, future
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
self.assertEqual(list(self.connect_futures.keys()),
[(AF1, 'a')])
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_immediate_failure(self):
# Fail with just one address.
conn, future = self.start_connect([(AF1, 'a')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_one_family_second_try_failure(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'), (AF2, 'c'))
self.resolve_connect(AF1, 'b', False)
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
# resolving 'a' after the connection has completed doesn't start 'b'
self.resolve_connect(AF1, 'a', False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
# resolving 'c' after completion closes the connection.
self.resolve_connect(AF2, 'c', True)
self.assertTrue(self.streams.pop('c').closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', False)
self.assert_pending((AF1, 'a'), (AF2, 'd'))
self.resolve_connect(AF2, 'd', False)
# one queue is now empty
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)

View file

@ -182,6 +182,7 @@ three
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
@ -192,6 +193,7 @@ three{%end%}
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
@ -202,6 +204,7 @@ three{%end%}
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue('# base.html:1' in exc_stack)
@ -214,6 +217,7 @@ three{%end%}
})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:1 (via base.html:1)" in
traceback.format_exc())
@ -225,6 +229,7 @@ three{%end%}
})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
@ -240,6 +245,7 @@ three{%end%}
"""})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:4 (via base.html:1)" in
traceback.format_exc())
@ -252,6 +258,7 @@ three{%end%}
})
try:
loader.load("a.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
traceback.format_exc())
@ -380,6 +387,20 @@ raw: {% raw name %}""",
self.assertEqual(render("foo.py", ["not a string"]),
b"""s = "['not a string']"\n""")
def test_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
loader = DictLoader({'foo.txt': """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
%}""",
})
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
b"0, 1, 2, 3, 4")
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):

View file

@ -8,11 +8,12 @@ from tornado.test.util import unittest
import contextlib
import os
import traceback
@contextlib.contextmanager
def set_environ(name, value):
old_value = os.environ.get('name')
old_value = os.environ.get(name)
os.environ[name] = value
try:
@ -62,6 +63,39 @@ class AsyncTestCaseTest(AsyncTestCase):
self.wait(timeout=0.15)
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):
class Test(AsyncTestCase):
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
def test_undecorated_generator_with_skip(self):
class Test(AsyncTestCase):
@unittest.skip("don't run this")
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.skipped), 1)
def test_other_return(self):
class Test(AsyncTestCase):
def test_other_return(self):
return 42
test = Test('test_other_return')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("Return value from test method ignored", result.errors[0][1])
class SetUpTearDownTest(unittest.TestCase):
def test_set_up_tear_down(self):
"""
@ -115,8 +149,17 @@ class GenTest(AsyncTestCase):
def test(self):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
with self.assertRaises(ioloop.TimeoutError):
# This can't use assertRaises because we need to inspect the
# exc_info triple (and not just the exception object)
try:
test(self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
self.assertIn(
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
traceback.format_exc())
self.finished = True
@ -155,5 +198,23 @@ class GenTest(AsyncTestCase):
self.finished = True
def test_with_method_args(self):
@gen_test
def test_with_args(self, *args):
self.assertEqual(args, ('test',))
yield gen.Task(self.io_loop.add_callback)
test_with_args(self, 'test')
self.finished = True
def test_with_method_kwargs(self):
@gen_test
def test_with_kwargs(self, **kwargs):
self.assertDictEqual(kwargs, {'test': 'test'})
yield gen.Task(self.io_loop.add_callback)
test_with_kwargs(self, test='test')
self.finished = True
if __name__ == '__main__':
unittest.main()

View file

@ -470,14 +470,17 @@ if have_twisted:
'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
'twisted.internet.test.test_core.SystemEventTestsBuilder': [
'test_iterate', # deliberately not supported
'test_runAfterCrash', # fails because TwistedIOLoop uses the global reactor
] if issubclass(IOLoop.configured_class(), TwistedIOLoop) else [
'test_iterate', # deliberately not supported
# Fails on TwistedIOLoop and AsyncIOLoop.
'test_runAfterCrash',
],
'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
"test_lostFileDescriptor", # incompatible with epoll and kqueue
],
'twisted.internet.test.test_process.ProcessTestsBuilder': [
# Only work as root. Twisted's "skip" functionality works
# with py27+, but not unittest2 on py26.
'test_changeGID',
'test_changeUID',
],
# Process tests appear to work on OSX 10.7, but not 10.6
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [

View file

@ -1,14 +1,18 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import socket
import sys
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info >= (2, 7):
import unittest
else:
if sys.version_info < (2, 7):
# In py26, we must always use unittest2.
import unittest2 as unittest
else:
# Otherwise, use whichever version of unittest was imported in
# tornado.testing.
from tornado.testing import unittest
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
"non-unix platform")
@ -17,3 +21,10 @@ skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
'timing tests unreliable on travis')
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')

View file

@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase):
self.replacer = ArgReplacer(function, 'callback')
def test_omitted(self):
self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
args = (1, 2)
kwargs = dict()
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
self.assertEqual(self.replacer.replace('new', args, kwargs),
(None, (1, 2), dict(callback='new')))
def test_position(self):
self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
args = (1, 2, 'old', 3)
kwargs = dict()
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', [1, 2, 'new', 3], dict()))
def test_keyword(self):
self.assertEqual(self.replacer.replace('new', (1,),
dict(y=2, callback='old', z=3)),
args = (1,)
kwargs = dict(y=2, callback='old', z=3)
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))

File diff suppressed because it is too large Load diff

View file

@ -1,21 +1,66 @@
from __future__ import absolute_import, division, print_function, with_statement
import traceback
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError
from tornado.log import gen_log
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
try:
from tornado import speedups
except ImportError:
speedups = None
class EchoHandler(WebSocketHandler):
class TestWebSocketHandler(WebSocketHandler):
"""Base class for testing handlers that exposes the on_close event.
This allows for deterministic cleanup of the associated socket.
"""
def initialize(self, close_future):
self.close_future = close_future
def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
def on_close(self):
self.close_future.set_result(None)
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
self.set_status(503)
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get('X-Test', ''))
class NonWebSocketHandler(RequestHandler):
@ -23,14 +68,29 @@ class NonWebSocketHandler(RequestHandler):
self.write('ok')
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
('/header', HeaderHandler, dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
])
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch('/echo')
self.assertEqual(response.code, 400)
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
@ -39,6 +99,8 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
def test_websocket_callbacks(self):
websocket_connect(
@ -49,6 +111,40 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, 'hello')
self.close_future.add_done_callback(lambda f: self.stop())
ws.close()
self.wait()
@gen_test
def test_binary_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
ws.close()
yield self.close_future
@gen_test
def test_unicode_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
ws.close()
yield self.close_future
@gen_test
def test_error_in_on_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/error_in_on_message' % self.get_http_port())
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
ws.close()
yield self.close_future
@gen_test
def test_websocket_http_fail(self):
@ -69,13 +165,12 @@ class WebSocketTest(AsyncHTTPTestCase):
def test_websocket_network_fail(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(HTTPError) as cm:
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
io_loop=self.io_loop,
connect_timeout=0.01)
self.assertEqual(cm.exception.code, 599)
connect_timeout=3600)
@gen_test
def test_websocket_close_buffered_data(self):
@ -85,3 +180,134 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('world')
ws.stream.close()
yield self.close_future
@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_server_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/close_reason' % self.get_http_port())
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
@gen_test
def test_client_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'localhost:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Host is localhost, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid_subdomains(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Subdomains should be disallowed by default. If we could pass a
# resolver to websocket_connect we could test sibling domains as well.
headers = {'Origin': 'http://subtenant.localhost'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):
self.assertEqual(self.mask(b'abcd', b''), b'')
self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
# Include test cases with \x00 bytes (to ensure that the C
# extension isn't depending on null-terminated strings) and
# bytes with the high bit set (to smoke out signedness issues).
self.assertEqual(self.mask(b'\x00\x01\x02\x03',
b'\xff\xfb\xfd\xfc\xfe\xfa'),
b'\xff\xfa\xff\xff\xfe\xfb')
self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
b'\x00\x01\x02\x03\x04\x05'),
b'\xff\xfa\xff\xff\xfb\xfe')
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return _websocket_mask_python(mask, data)
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return speedups.websocket_mask(mask, data)

View file

@ -5,8 +5,8 @@ from tornado.escape import json_decode
from tornado.test.httpserver_test import TypeCheckHandler
from tornado.testing import AsyncHTTPTestCase
from tornado.util import u
from tornado.web import RequestHandler
from tornado.wsgi import WSGIApplication, WSGIContainer
from tornado.web import RequestHandler, Application
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
class WSGIContainerTest(AsyncHTTPTestCase):
@ -74,14 +74,27 @@ class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
def wrap_web_tests():
def wrap_web_tests_application():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIWrappedTest(cls):
class WSGIApplicationWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
return result
globals().update(wrap_web_tests())
globals().update(wrap_web_tests_application())
def wrap_web_tests_adapter():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIAdapterWrappedTest(cls):
def get_app(self):
self.app = Application(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(WSGIAdapter(self.app)))
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
return result
globals().update(wrap_web_tests_adapter())