update shared deps

This commit is contained in:
j 2019-01-13 13:31:53 +05:30
commit 642ba49f68
275 changed files with 31987 additions and 19235 deletions

View file

@ -2,7 +2,7 @@
This only works in python 2.7+.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado.test.runtests import all, main

View file

@ -10,9 +10,11 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
@ -21,7 +23,7 @@ try:
except ImportError:
asyncio = None
else:
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
# This is used in dynamically-evaluated code, so silence pyflakes.
to_asyncio_future
@ -30,7 +32,6 @@ else:
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def test_asyncio_callback(self):
@ -41,8 +42,15 @@ class AsyncIOLoopTest(AsyncTestCase):
@gen_test
def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in asyncio.async.
x = yield asyncio.async(
# Without 'yield from', we must wrap coroutines in ensure_future,
# which was introduced during Python 3.4, deprecating the prior "async".
if hasattr(asyncio, 'ensure_future'):
ensure_future = asyncio.ensure_future
else:
# async is a reserved word in Python 3.7
ensure_future = getattr(asyncio, 'async')
x = yield ensure_future(
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
self.assertEqual(x, 42)
@ -69,7 +77,7 @@ class AsyncIOLoopTest(AsyncTestCase):
# as demonstrated by other tests in the package.
@gen.coroutine
def tornado_coroutine():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
raise gen.Return(42)
native_coroutine_without_adapter = exec_test(globals(), locals(), """
async def native_coroutine_without_adapter():
@ -99,10 +107,11 @@ class AsyncIOLoopTest(AsyncTestCase):
42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures.
with self.assertRaises(RuntimeError):
# Futures (which our Future is since 5.0).
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_without_adapter())
native_coroutine_without_adapter()),
42)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter()),
@ -111,3 +120,87 @@ class AsyncIOLoopTest(AsyncTestCase):
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter2()),
42)
@unittest.skipIf(asyncio is None, "asyncio module not present")
class LeakTest(unittest.TestCase):
    """Regression tests that closing event loops does not leak entries
    in the ``IOLoop._ioloop_for_asyncio`` mapping (asyncio loop ->
    Tornado IOLoop wrapper).
    """

    def setUp(self):
        # Trigger a cleanup of the mapping so we start with a clean slate.
        AsyncIOLoop().close()
        # If we don't clean up after ourselves other tests may fail on
        # py34.
        self.orig_policy = asyncio.get_event_loop_policy()
        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())

    def tearDown(self):
        # Close the loop created under the temporary default policy and
        # restore whatever policy was active before the test.
        asyncio.get_event_loop().close()
        asyncio.set_event_loop_policy(self.orig_policy)

    def test_ioloop_close_leak(self):
        # Closing through the Tornado API should remove the mapping
        # entry immediately: the count must return to its start value.
        orig_count = len(IOLoop._ioloop_for_asyncio)
        for i in range(10):
            # Create and close an AsyncIOLoop using Tornado interfaces.
            loop = AsyncIOLoop()
            loop.close()
        new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
        self.assertEqual(new_count, 0)

    def test_asyncio_close_leak(self):
        orig_count = len(IOLoop._ioloop_for_asyncio)
        for i in range(10):
            # Create and close an AsyncIOMainLoop using asyncio interfaces.
            loop = asyncio.new_event_loop()
            # IOLoop.current() registers a Tornado wrapper for this
            # asyncio loop — the entry whose cleanup is under test.
            loop.call_soon(IOLoop.current)
            loop.call_soon(loop.stop)
            loop.run_forever()
            loop.close()
        new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
        # Because the cleanup is run on new loop creation, we have one
        # dangling entry in the map (but only one).
        self.assertEqual(new_count, 1)
@unittest.skipIf(asyncio is None, "asyncio module not present")
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
    """Tests for ``AnyThreadEventLoopPolicy``, which lets non-main
    threads obtain an event loop implicitly (the default asyncio policy
    refuses to create loops on non-main threads).
    """

    def setUp(self):
        self.orig_policy = asyncio.get_event_loop_policy()
        # Single-threaded executor: every submit() below runs on the
        # same non-main worker thread.
        self.executor = ThreadPoolExecutor(1)

    def tearDown(self):
        asyncio.set_event_loop_policy(self.orig_policy)
        self.executor.shutdown()

    def get_event_loop_on_thread(self):
        def get_and_close_event_loop():
            """Get the event loop. Close it if one is returned.

            Returns the (closed) event loop. This is a silly thing
            to do and leaves the thread in a broken state, but it's
            enough for this test. Closing the loop avoids resource
            leak warnings.
            """
            loop = asyncio.get_event_loop()
            loop.close()
            return loop
        future = self.executor.submit(get_and_close_event_loop)
        return future.result()

    def run_policy_test(self, accessor, expected_type):
        """Run ``accessor`` on the worker thread under both policies.

        ``accessor`` is a zero-argument loop accessor (e.g.
        ``asyncio.get_event_loop`` or ``IOLoop.current``);
        ``expected_type`` is the type it should return once
        ``AnyThreadEventLoopPolicy`` is installed.
        """
        # With the default policy, non-main threads don't get an event
        # loop.
        self.assertRaises((RuntimeError, AssertionError),
                          self.executor.submit(accessor).result)
        # Set the policy and we can get a loop.
        asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        self.assertIsInstance(
            self.executor.submit(accessor).result(),
            expected_type)
        # Clean up to silence leak warnings. Always use asyncio since
        # IOLoop doesn't (currently) close the underlying loop.
        self.executor.submit(lambda: asyncio.get_event_loop().close()).result()

    def test_asyncio_accessor(self):
        self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)

    def test_tornado_accessor(self):
        self.run_policy_test(IOLoop.current, IOLoop)

View file

@ -4,37 +4,69 @@
# python 3)
from __future__ import absolute_import, division, print_function, with_statement
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError, GoogleOAuth2Mixin, FacebookGraphMixin
from __future__ import absolute_import, division, print_function
import unittest
import warnings
from tornado.auth import (
AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin,
GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
)
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.util import u
from tornado.test.util import ignore_deprecation
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
try:
from unittest import mock
except ImportError:
mock = None
class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
    """Exercises the deprecated callback-style OpenIdMixin interface."""

    def initialize(self, test):
        self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')

    # @asynchronous is deprecated; define the handler inside
    # ignore_deprecation() so importing this module stays warning-free.
    with ignore_deprecation():
        @asynchronous
        def get(self):
            if self.get_argument('openid.mode', None):
                # Explicit-callback get_authenticated_user is the
                # deprecated code path under test here.
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    self.get_authenticated_user(
                        self.on_user, http_client=self.settings['http_client'])
                return
            res = self.authenticate_redirect()
            assert isinstance(res, Future)
            assert res.done()

    def on_user(self, user):
        if user is None:
            raise Exception("user is None")
        self.finish(user)
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
@gen.coroutine
def get(self):
if self.get_argument('openid.mode', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
if user is None:
raise Exception("user is None")
self.finish(user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
@ -43,7 +75,7 @@ class OpenIdServerAuthenticateHandler(RequestHandler):
self.write('is_valid:true')
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
@ -53,14 +85,17 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authorize_redirect(http_client=self.settings['http_client'])
assert isinstance(res, Future)
with ignore_deprecation():
@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authorize_redirect(http_client=self.settings['http_client'])
assert isinstance(res, Future)
def on_user(self, user):
if user is None:
@ -75,6 +110,35 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
callback(dict(email='foo@example.com'))
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
    """OAuth 1.0/1.0a login handler using the coroutine interface."""

    def initialize(self, test):
        pass  # placeholder removed — see real initialize below

    def initialize(self, test, version):
        # version selects the OAuthMixin protocol variant under test
        # ('1.0' or '1.0a', per the routes that register this handler).
        self._OAUTH_VERSION = version
        self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
        self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
        self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')

    def _oauth_consumer_token(self):
        return dict(key='asdf', secret='qwer')

    @gen.coroutine
    def get(self):
        if self.get_argument('oauth_token', None):
            # Second leg: exchange the request token for the user info.
            user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        # First leg: redirect to the (simulated) authorization server.
        yield self.authorize_redirect(http_client=self.settings['http_client'])

    @gen.coroutine
    def _oauth_get_user_future(self, access_token):
        if self.get_argument('fail_in_get_user', None):
            raise Exception("failing in get_user")
        if access_token != dict(key='uiop', secret='5678'):
            raise Exception("incorrect access token %r" % access_token)
        # No yield in this body, so a plain return is valid even with
        # the @gen.coroutine decorator.
        return dict(email='foo@example.com')
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
@ -150,7 +214,7 @@ class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
class FacebookServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('access_token=asdf')
self.write(dict(access_token="asdf", expires_in=3600))
class FacebookServerMeHandler(RequestHandler):
@ -163,19 +227,21 @@ class TwitterClientHandler(RequestHandler, TwitterMixin):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate')
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
def get_auth_http_client(self):
return self.settings['http_client']
class TwitterClientLoginHandler(TwitterClientHandler):
@asynchronous
def get(self):
if self.get_argument("oauth_token", None):
self.get_authenticated_user(self.on_user)
return
self.authorize_redirect()
class TwitterClientLoginHandlerLegacy(TwitterClientHandler):
with ignore_deprecation():
@asynchronous
def get(self):
if self.get_argument("oauth_token", None):
self.get_authenticated_user(self.on_user)
return
self.authorize_redirect()
def on_user(self, user):
if user is None:
@ -183,17 +249,44 @@ class TwitterClientLoginHandler(TwitterClientHandler):
self.finish(user)
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
@asynchronous
@gen.engine
class TwitterClientLoginHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
if user is None:
raise Exception("user is None")
self.finish(user)
else:
# Old style: with @gen.engine we can ignore the Future from
# authorize_redirect.
self.authorize_redirect()
return
yield self.authorize_redirect()
class TwitterClientAuthenticateHandler(TwitterClientHandler):
    # Like TwitterClientLoginHandler, but uses authenticate_redirect
    # instead of authorize_redirect.
    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            # Callback leg: exchange the token for user info.
            user = yield self.get_authenticated_user()
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        # Initial leg: redirect to the authenticate (not authorize) URL.
        yield self.authenticate_redirect()
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
with ignore_deprecation():
@asynchronous
@gen.engine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# Old style: with @gen.engine we can ignore the Future from
# authorize_redirect.
self.authorize_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@ -208,36 +301,39 @@ class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
yield self.authorize_redirect()
class TwitterClientShowUserHandlerLegacy(TwitterClientHandler):
    """Exercises the deprecated gen.Task/gen.engine twitter_request path."""

    with ignore_deprecation():
        @asynchronous
        @gen.engine
        def get(self):
            # TODO: would be nice to go through the login flow instead of
            # cheating with a hard-coded access token.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                response = yield gen.Task(self.twitter_request,
                                          '/users/show/%s' % self.get_argument('name'),
                                          access_token=dict(key='hjkl', secret='vbnm'))
            # In the legacy interface a failed request yields None
            # rather than raising.
            if response is None:
                self.set_status(500)
                self.finish('error from twitter request')
            else:
                self.finish(response)
class TwitterClientShowUserHandler(TwitterClientHandler):
@asynchronous
@gen.engine
@gen.coroutine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
response = yield gen.Task(self.twitter_request,
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
if response is None:
self.set_status(500)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterClientShowUserFutureHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
try:
response = yield self.twitter_request(
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
except AuthError as e:
except AuthError:
self.set_status(500)
self.finish(str(e))
return
assert response is not None
self.finish(response)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
@ -276,12 +372,17 @@ class AuthTest(AsyncHTTPTestCase):
return Application(
[
# test endpoints
('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)),
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy,
dict(test=self, version='1.0')),
('/oauth10/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0')),
('/oauth10/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0')),
('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy,
dict(test=self, version='1.0a')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/login_coroutine',
@ -294,11 +395,17 @@ class AuthTest(AsyncHTTPTestCase):
('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
('/legacy/twitter/client/login', TwitterClientLoginHandlerLegacy, dict(test=self)),
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
('/twitter/client/login_gen_engine', TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)),
('/twitter/client/login_gen_engine',
TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine',
TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/legacy/twitter/client/show_user',
TwitterClientShowUserHandlerLegacy, dict(test=self)),
('/twitter/client/show_user',
TwitterClientShowUserHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
@ -309,7 +416,8 @@ class AuthTest(AsyncHTTPTestCase):
('/facebook/server/me', FacebookServerMeHandler),
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
(r'/twitter/api/account/verify_credentials\.json', TwitterServerVerifyCredentialsHandler),
(r'/twitter/api/account/verify_credentials\.json',
TwitterServerVerifyCredentialsHandler),
],
http_client=self.http_client,
twitter_consumer_key='test_twitter_consumer_key',
@ -317,6 +425,21 @@ class AuthTest(AsyncHTTPTestCase):
facebook_api_key='test_facebook_api_key',
facebook_secret='test_facebook_secret')
def test_openid_redirect_legacy(self):
response = self.fetch('/legacy/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user_legacy(self):
response = self.fetch('/legacy/openid/client/login?openid.mode=blah'
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
'&openid.ax.type.email=http://axschema.org/contact/email'
'&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
@ -324,11 +447,24 @@ class AuthTest(AsyncHTTPTestCase):
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user(self):
response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
response = self.fetch('/openid/client/login?openid.mode=blah'
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
'&openid.ax.type.email=http://axschema.org/contact/email'
'&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect_legacy(self):
response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10_redirect(self):
response = self.fetch('/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
@ -339,6 +475,16 @@ class AuthTest(AsyncHTTPTestCase):
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10_get_user_legacy(self):
with ignore_deprecation():
response = self.fetch(
'/legacy/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10_get_user(self):
response = self.fetch(
'/oauth10/client/login?oauth_token=zxcv',
@ -357,6 +503,26 @@ class AuthTest(AsyncHTTPTestCase):
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_redirect_legacy(self):
response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10a_get_user_legacy(self):
with ignore_deprecation():
response = self.fetch(
'/legacy/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10a_redirect(self):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
@ -367,6 +533,14 @@ class AuthTest(AsyncHTTPTestCase):
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
@unittest.skipIf(mock is None, 'mock package not present')
def test_oauth10a_redirect_error(self):
with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
get.side_effect = Exception("boom")
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 500)
def test_oauth10a_get_user(self):
response = self.fetch(
'/oauth10a/client/login?oauth_token=zxcv',
@ -402,6 +576,9 @@ class AuthTest(AsyncHTTPTestCase):
self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
self.assertEqual(response.code, 200)
user = json_decode(response.body)
self.assertEqual(user['access_token'], 'asdf')
self.assertEqual(user['session_expires'], '3600')
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
@ -414,6 +591,9 @@ class AuthTest(AsyncHTTPTestCase):
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_redirect_legacy(self):
self.base_twitter_redirect('/legacy/twitter/client/login')
def test_twitter_redirect(self):
self.base_twitter_redirect('/twitter/client/login')
@ -423,6 +603,16 @@ class AuthTest(AsyncHTTPTestCase):
def test_twitter_redirect_gen_coroutine(self):
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
def test_twitter_authenticate_redirect(self):
response = self.fetch('/twitter/client/authenticate', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location'])
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_get_user(self):
response = self.fetch(
'/twitter/client/login?oauth_token=zxcv',
@ -430,12 +620,24 @@ class AuthTest(AsyncHTTPTestCase):
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed,
{u('access_token'): {u('key'): u('hjkl'),
u('screen_name'): u('foo'),
u('secret'): u('vbnm')},
u('name'): u('Foo'),
u('screen_name'): u('foo'),
u('username'): u('foo')})
{u'access_token': {u'key': u'hjkl',
u'screen_name': u'foo',
u'secret': u'vbnm'},
u'name': u'Foo',
u'screen_name': u'foo',
u'username': u'foo'})
def test_twitter_show_user_legacy(self):
response = self.fetch('/legacy/twitter/client/show_user?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error_legacy(self):
with ExpectLog(gen_log, 'Error response HTTP 500'):
response = self.fetch('/legacy/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
def test_twitter_show_user(self):
response = self.fetch('/twitter/client/show_user?name=somebody')
@ -444,22 +646,10 @@ class AuthTest(AsyncHTTPTestCase):
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error(self):
with ExpectLog(gen_log, 'Error response HTTP 500'):
response = self.fetch('/twitter/client/show_user?name=error')
response = self.fetch('/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
def test_twitter_show_user_future(self):
response = self.fetch('/twitter/client/show_user_future?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_future_error(self):
response = self.fetch('/twitter/client/show_user_future?name=error')
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
@ -539,7 +729,7 @@ class GoogleOAuth2Test(AsyncHTTPTestCase):
def test_google_login(self):
response = self.fetch('/client/login')
self.assertDictEqual({
u('name'): u('Foo'),
u('email'): u('foo@example.com'),
u('access_token'): u('fake-access-token'),
u'name': u'Foo',
u'email': u'foo@example.com',
u'access_token': u'fake-access-token',
}, json_decode(response.body))

View file

@ -0,0 +1,114 @@
from __future__ import absolute_import, division, print_function
import os
import shutil
import subprocess
from subprocess import Popen
import sys
from tempfile import mkdtemp
import time
from tornado.test.util import unittest
class AutoreloadTest(unittest.TestCase):
    """End-to-end tests for tornado.autoreload, driven via subprocesses."""

    def test_reload_module(self):
        # The child script imports a sibling package and restarts itself
        # once via autoreload._reload(), so 'Starting' is printed
        # exactly twice.
        main = """\
import os
import sys

from tornado import autoreload

# This import will fail if path is not set up correctly
import testapp

print('Starting')
if 'TESTAPP_STARTED' not in os.environ:
    os.environ['TESTAPP_STARTED'] = '1'
    sys.stdout.flush()
    autoreload._reload()
"""

        # Create temporary test application
        path = mkdtemp()
        self.addCleanup(shutil.rmtree, path)
        os.mkdir(os.path.join(path, 'testapp'))
        open(os.path.join(path, 'testapp/__init__.py'), 'w').close()
        with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f:
            f.write(main)

        # Make sure the tornado module under test is available to the test
        # application
        pythonpath = os.getcwd()
        if 'PYTHONPATH' in os.environ:
            pythonpath += os.pathsep + os.environ['PYTHONPATH']

        p = Popen(
            [sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
            cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath),
            universal_newlines=True)
        out = p.communicate()[0]
        self.assertEqual(out, 'Starting\nStarting\n')

    def test_reload_wrapper_preservation(self):
        # This test verifies that when `python -m tornado.autoreload`
        # is used on an application that also has an internal
        # autoreload, the reload wrapper is preserved on restart.
        main = """\
import os
import sys

# This import will fail if path is not set up correctly
import testapp

if 'tornado.autoreload' not in sys.modules:
    raise Exception('started without autoreload wrapper')

import tornado.autoreload

print('Starting')
sys.stdout.flush()
if 'TESTAPP_STARTED' not in os.environ:
    os.environ['TESTAPP_STARTED'] = '1'
    # Simulate an internal autoreload (one not caused
    # by the wrapper).
    tornado.autoreload._reload()
else:
    # Exit directly so autoreload doesn't catch it.
    os._exit(0)
"""

        # Create temporary test application
        path = mkdtemp()
        os.mkdir(os.path.join(path, 'testapp'))
        self.addCleanup(shutil.rmtree, path)
        init_file = os.path.join(path, 'testapp', '__init__.py')
        open(init_file, 'w').close()
        main_file = os.path.join(path, 'testapp', '__main__.py')
        with open(main_file, 'w') as f:
            f.write(main)

        # Make sure the tornado module under test is available to the test
        # application
        pythonpath = os.getcwd()
        if 'PYTHONPATH' in os.environ:
            pythonpath += os.pathsep + os.environ['PYTHONPATH']

        autoreload_proc = Popen(
            [sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'],
            stdout=subprocess.PIPE, cwd=path,
            env=dict(os.environ, PYTHONPATH=pythonpath),
            universal_newlines=True)

        # This timeout needs to be fairly generous for pypy due to jit
        # warmup costs.
        for i in range(40):
            if autoreload_proc.poll() is not None:
                break
            time.sleep(0.1)
        else:
            autoreload_proc.kill()
            raise Exception("subprocess failed to terminate")

        out = autoreload_proc.communicate()[0]
        self.assertEqual(out, 'Starting\n' * 2)

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
@ -13,22 +12,27 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import gc
import logging
import re
import socket
import sys
import traceback
import warnings
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
run_on_executor, future_set_result_unless_cancelled)
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
from tornado.test.util import unittest
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
try:
@ -37,33 +41,51 @@ except ImportError:
futures = None
class MiscFutureTest(AsyncTestCase):
    """Tests for small Future helpers in tornado.concurrent."""

    def test_future_set_result_unless_cancelled(self):
        # Normal case: result is set, future is not cancelled.
        fut = Future()
        future_set_result_unless_cancelled(fut, 42)
        self.assertEqual(fut.result(), 42)
        self.assertFalse(fut.cancelled())

        # Cancelled case: setting the result must not raise.
        fut = Future()
        fut.cancel()
        is_cancelled = fut.cancelled()
        future_set_result_unless_cancelled(fut, 42)
        self.assertEqual(fut.cancelled(), is_cancelled)
        if not is_cancelled:
            # cancel() can be a no-op on some Future implementations;
            # in that case the result must still have been set.
            self.assertEqual(fut.result(), 42)
class ReturnFutureTest(AsyncTestCase):
@return_future
def sync_future(self, callback):
callback(42)
with ignore_deprecation():
@return_future
def sync_future(self, callback):
callback(42)
@return_future
def async_future(self, callback):
self.io_loop.add_callback(callback, 42)
@return_future
def async_future(self, callback):
self.io_loop.add_callback(callback, 42)
@return_future
def immediate_failure(self, callback):
1 / 0
@return_future
def immediate_failure(self, callback):
1 / 0
@return_future
def delayed_failure(self, callback):
self.io_loop.add_callback(lambda: 1 / 0)
@return_future
def delayed_failure(self, callback):
self.io_loop.add_callback(lambda: 1 / 0)
@return_future
def return_value(self, callback):
# Note that the result of both running the callback and returning
# a value (or raising an exception) is unspecified; with current
# implementations the last event prior to callback resolution wins.
return 42
@return_future
def return_value(self, callback):
# Note that the result of both running the callback and returning
# a value (or raising an exception) is unspecified; with current
# implementations the last event prior to callback resolution wins.
return 42
@return_future
def no_result_future(self, callback):
callback()
@return_future
def no_result_future(self, callback):
callback()
def test_immediate_failure(self):
with self.assertRaises(ZeroDivisionError):
@ -79,7 +101,8 @@ class ReturnFutureTest(AsyncTestCase):
self.return_value(callback=self.stop)
def test_callback_kw(self):
future = self.sync_future(callback=self.stop)
with ignore_deprecation():
future = self.sync_future(callback=self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
@ -87,7 +110,8 @@ class ReturnFutureTest(AsyncTestCase):
def test_callback_positional(self):
# When the callback is passed in positionally, future_wrap shouldn't
# add another callback in the kwargs.
future = self.sync_future(self.stop)
with ignore_deprecation():
future = self.sync_future(self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
@ -120,44 +144,68 @@ class ReturnFutureTest(AsyncTestCase):
def test_delayed_failure(self):
future = self.delayed_failure()
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
with ignore_deprecation():
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
with self.assertRaises(ZeroDivisionError):
future.result()
def test_kw_only_callback(self):
@return_future
def f(**kwargs):
kwargs['callback'](42)
with ignore_deprecation():
@return_future
def f(**kwargs):
kwargs['callback'](42)
future = f()
self.assertEqual(future.result(), 42)
def test_error_in_callback(self):
self.sync_future(callback=lambda future: 1 / 0)
with ignore_deprecation():
self.sync_future(callback=lambda future: 1 / 0)
# The exception gets caught by our StackContext and will be re-raised
# when we wait.
self.assertRaises(ZeroDivisionError, self.wait)
def test_no_result_future(self):
future = self.no_result_future(self.stop)
with ignore_deprecation():
future = self.no_result_future(self.stop)
result = self.wait()
self.assertIs(result, None)
# result of this future is undefined, but not an error
future.result()
def test_no_result_future_callback(self):
future = self.no_result_future(callback=lambda: self.stop())
with ignore_deprecation():
future = self.no_result_future(callback=lambda: self.stop())
result = self.wait()
self.assertIs(result, None)
future.result()
@gen_test
def test_future_traceback_legacy(self):
with ignore_deprecation():
@return_future
@gen.engine
def f(callback):
yield gen.Task(self.io_loop.add_callback)
try:
1 / 0
except ZeroDivisionError:
self.expected_frame = traceback.extract_tb(
sys.exc_info()[2], limit=1)[0]
raise
try:
yield f()
self.fail("didn't get expected exception")
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
@gen_test
def test_future_traceback(self):
@return_future
@gen.engine
def f(callback):
yield gen.Task(self.io_loop.add_callback)
@gen.coroutine
def f():
yield gen.moment
try:
1 / 0
except ZeroDivisionError:
@ -171,25 +219,50 @@ class ReturnFutureTest(AsyncTestCase):
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
@gen_test
def test_uncaught_exception_log(self):
if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
# Install an exception handler that mirrors our
# non-asyncio logging behavior.
def exc_handler(loop, context):
app_log.error('%s: %s', context['message'],
type(context.get('exception')))
self.io_loop.asyncio_loop.set_exception_handler(exc_handler)
@gen.coroutine
def f():
yield gen.moment
1 / 0
g = f()
with ExpectLog(app_log,
"(?s)Future.* exception was never retrieved:"
".*ZeroDivisionError"):
yield gen.moment
yield gen.moment
# For some reason, TwistedIOLoop and pypy3 need a third iteration
# in order to drain references to the future
yield gen.moment
del g
gc.collect() # for PyPy
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
logging.info("handle_stream")
self.stream = stream
self.stream.read_until(b"\n", self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
data = yield stream.read_until(b"\n")
data = to_unicode(data)
if data == data.upper():
self.stream.write(b"error\talready capitalized\n")
stream.write(b"error\talready capitalized\n")
else:
# data already has \n
self.stream.write(utf8("ok\t%s" % data.upper()))
self.stream.close()
stream.write(utf8("ok\t%s" % data.upper()))
stream.close()
class CapError(Exception):
@ -197,9 +270,8 @@ class CapError(Exception):
class BaseCapClient(object):
def __init__(self, port, io_loop):
def __init__(self, port):
self.port = port
self.io_loop = io_loop
def process_response(self, data):
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
@ -211,9 +283,9 @@ class BaseCapClient(object):
class ManualCapClient(BaseCapClient):
def capitalize(self, request_data, callback=None):
logging.info("capitalize")
logging.debug("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream = IOStream(socket.socket())
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.future = Future()
@ -223,12 +295,12 @@ class ManualCapClient(BaseCapClient):
return self.future
def handle_connect(self):
logging.info("handle_connect")
logging.debug("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
logging.debug("handle_read")
self.stream.close()
try:
self.future.set_result(self.process_response(data))
@ -237,62 +309,64 @@ class ManualCapClient(BaseCapClient):
class DecoratorCapClient(BaseCapClient):
@return_future
def capitalize(self, request_data, callback):
logging.info("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.callback = callback
with ignore_deprecation():
@return_future
def capitalize(self, request_data, callback):
logging.debug("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket())
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.callback = callback
def handle_connect(self):
logging.info("handle_connect")
logging.debug("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
logging.debug("handle_read")
self.stream.close()
self.callback(self.process_response(data))
class GeneratorCapClient(BaseCapClient):
@return_future
@gen.engine
def capitalize(self, request_data, callback):
logging.info('capitalize')
stream = IOStream(socket.socket(), io_loop=self.io_loop)
logging.info('connecting')
yield gen.Task(stream.connect, ('127.0.0.1', self.port))
@gen.coroutine
def capitalize(self, request_data):
logging.debug('capitalize')
stream = IOStream(socket.socket())
logging.debug('connecting')
yield stream.connect(('127.0.0.1', self.port))
stream.write(utf8(request_data + '\n'))
logging.info('reading')
data = yield gen.Task(stream.read_until, b'\n')
logging.info('returning')
logging.debug('reading')
data = yield stream.read_until(b'\n')
logging.debug('returning')
stream.close()
callback(self.process_response(data))
raise gen.Return(self.process_response(data))
class ClientTestMixin(object):
def setUp(self):
super(ClientTestMixin, self).setUp()
self.server = CapServer(io_loop=self.io_loop)
super(ClientTestMixin, self).setUp() # type: ignore
self.server = CapServer()
sock, port = bind_unused_port()
self.server.add_sockets([sock])
self.client = self.client_class(io_loop=self.io_loop, port=port)
self.client = self.client_class(port=port)
def tearDown(self):
self.server.stop()
super(ClientTestMixin, self).tearDown()
super(ClientTestMixin, self).tearDown() # type: ignore
def test_callback(self):
self.client.capitalize("hello", callback=self.stop)
with ignore_deprecation():
self.client.capitalize("hello", callback=self.stop)
result = self.wait()
self.assertEqual(result, "HELLO")
def test_callback_error(self):
self.client.capitalize("HELLO", callback=self.stop)
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
with ignore_deprecation():
self.client.capitalize("HELLO", callback=self.stop)
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
def test_future(self):
future = self.client.capitalize("hello")
@ -307,33 +381,49 @@ class ClientTestMixin(object):
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
def test_generator(self):
@gen.engine
@gen.coroutine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
self.stop()
f()
self.wait()
self.io_loop.run_sync(f)
def test_generator_error(self):
@gen.engine
@gen.coroutine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
self.stop()
f()
self.wait()
self.io_loop.run_sync(f)
class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
class ManualClientTest(ClientTestMixin, AsyncTestCase):
client_class = ManualCapClient
def setUp(self):
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
super(ManualClientTest, self).setUp()
class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
def tearDown(self):
super(ManualClientTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = DecoratorCapClient
def setUp(self):
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
super(DecoratorClientTest, self).setUp()
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
def tearDown(self):
super(DecoratorClientTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = GeneratorCapClient
@ -342,74 +432,65 @@ class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor
def f(self):
return 42
o = Object(io_loop=self.io_loop)
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_no_args(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_io_loop(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_executor(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
def __init__(self):
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(executor='_Object__executor')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@skipBefore35
@gen_test
def test_call_with_both(self):
def test_async_await(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.__executor = futures.thread.ThreadPoolExecutor(1)
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop', executor='_Object__executor')
@run_on_executor()
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
o = Object()
namespace = exec_test(globals(), locals(), """
async def f():
answer = await o.f()
return answer
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
if __name__ == '__main__':
unittest.main()

View file

@ -1,18 +1,20 @@
from __future__ import absolute_import, division, print_function, with_statement
# coding: utf-8
from __future__ import absolute_import, division, print_function
from hashlib import md5
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest
from tornado.httpclient import HTTPRequest, HTTPClientError
from tornado.locks import Event
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.testing import AsyncHTTPTestCase, gen_test
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.test.util import unittest, ignore_deprecation
from tornado.web import Application, RequestHandler
try:
import pycurl
import pycurl # type: ignore
except ImportError:
pycurl = None
@ -23,21 +25,22 @@ if pycurl is not None:
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(io_loop=self.io_loop,
defaults=dict(allow_ipv6=False))
client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
class DigestAuthHandler(RequestHandler):
def initialize(self, username, password):
self.username = username
self.password = password
def get(self):
realm = 'test'
opaque = 'asdf'
# Real implementations would use a random nonce.
nonce = "1234"
username = 'foo'
password = 'bar'
auth_header = self.request.headers.get('Authorization', None)
if auth_header is not None:
@ -52,9 +55,9 @@ class DigestAuthHandler(RequestHandler):
assert param_dict['realm'] == realm
assert param_dict['opaque'] == opaque
assert param_dict['nonce'] == nonce
assert param_dict['username'] == username
assert param_dict['username'] == self.username
assert param_dict['uri'] == self.request.path
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest()
h2 = md5(utf8('%s:%s' % (self.request.method,
self.request.path))).hexdigest()
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
@ -83,29 +86,36 @@ class CustomFailReasonHandler(RequestHandler):
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = CurlAsyncHTTPClient(self.io_loop,
defaults=dict(allow_ipv6=False))
self.http_client = self.create_client()
def get_app(self):
return Application([
('/digest', DigestAuthHandler),
('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}),
('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}),
('/custom_reason', CustomReasonHandler),
('/custom_fail_reason', CustomFailReasonHandler),
])
def create_client(self, **kwargs):
return CurlAsyncHTTPClient(force_instance=True,
defaults=dict(allow_ipv6=False),
**kwargs)
@gen_test
def test_prepare_curl_callback_stack_context(self):
exc_info = []
error_event = Event()
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
self.stop()
error_event.set()
return True
with ExceptionStackContext(error_handler):
request = HTTPRequest(self.get_url('/'),
prepare_curl_callback=lambda curl: 1 / 0)
self.http_client.fetch(request, callback=self.stop)
self.wait()
with ignore_deprecation():
with ExceptionStackContext(error_handler):
request = HTTPRequest(self.get_url('/custom_reason'),
prepare_curl_callback=lambda curl: 1 / 0)
yield [error_event.wait(), self.http_client.fetch(request)]
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
@ -122,3 +132,22 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
def test_failed_setup(self):
self.http_client = self.create_client(max_clients=1)
for i in range(5):
with ignore_deprecation():
response = self.fetch(u'/ユニコード')
self.assertIsNot(response.error, None)
with self.assertRaises((UnicodeEncodeError, HTTPClientError)):
# This raises UnicodeDecodeError on py3 and
# HTTPClientError(404) on py2. The main motivation of
# this test is to ensure that the UnicodeEncodeError
# during the setup phase doesn't lead the request to
# be dropped on the floor.
response = self.fetch(u'/ユニコード', raise_error=True)
def test_digest_auth_non_ascii(self):
response = self.fetch('/digest_non_ascii', auth_mode='digest',
auth_username='foo', auth_password='barユ£')
self.assertEqual(response.body, b'ok')

View file

@ -1,134 +1,138 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode, squeeze, recursive_unicode
from tornado.util import u, unicode_type
from tornado.escape import (
utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
)
from tornado.util import unicode_type
from tornado.test.util import unittest
linkify_tests = [
# (input, linkify_kwargs, expected_output)
("hello http://world.com/!", {},
u('hello <a href="http://world.com/">http://world.com/</a>!')),
u'hello <a href="http://world.com/">http://world.com/</a>!'),
("hello http://world.com/with?param=true&stuff=yes", {},
u('hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>')),
u'hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>'), # noqa: E501
# an opened paren followed by many chars killed Gruber's regex
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
u('<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')),
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
# as did too many dots at the end
("http://url.com/withmany.......................................", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................')),
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................'), # noqa: E501
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)')),
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)'), # noqa: E501
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
("http://foo.com/blah_blah", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>')),
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>'),
("http://foo.com/blah_blah/", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>')),
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>'),
("(Something like http://foo.com/blah_blah)", {},
u('(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)')),
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)'),
("http://foo.com/blah_blah_(wikipedia)", {},
u('<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>')),
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>'),
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
u('<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>')),
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>'), # noqa: E501
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
u('(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)')),
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)'), # noqa: E501
("http://foo.com/blah_blah.", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.')),
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.'),
("http://foo.com/blah_blah/.", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.')),
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.'),
("<http://foo.com/blah_blah>", {},
u('&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;')),
u'&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;'),
("<http://foo.com/blah_blah/>", {},
u('&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;')),
u'&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;'),
("http://foo.com/blah_blah,", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,')),
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,'),
("http://www.example.com/wpstyle/?p=364.", {},
u('<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.')),
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.'),
("rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u('<a href="rdar://1234">rdar://1234</a>')),
u'<a href="rdar://1234">rdar://1234</a>'),
("rdar:/1234",
{"permitted_protocols": ["rdar"]},
u('<a href="rdar:/1234">rdar:/1234</a>')),
u'<a href="rdar:/1234">rdar:/1234</a>'),
("http://userid:password@example.com:8080", {},
u('<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>')),
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>'), # noqa: E501
("http://userid@example.com", {},
u('<a href="http://userid@example.com">http://userid@example.com</a>')),
u'<a href="http://userid@example.com">http://userid@example.com</a>'),
("http://userid@example.com:8080", {},
u('<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>')),
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>'),
("http://userid:password@example.com", {},
u('<a href="http://userid:password@example.com">http://userid:password@example.com</a>')),
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>'),
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u('<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>')),
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>'),
(u("http://\u27a1.ws/\u4a39"), {},
u('<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>')),
(u"http://\u27a1.ws/\u4a39", {},
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>'),
("<tag>http://example.com</tag>", {},
u('&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;')),
u'&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;'),
("Just a www.example.com link.", {},
u('Just a <a href="http://www.example.com">www.example.com</a> link.')),
u'Just a <a href="http://www.example.com">www.example.com</a> link.'),
("Just a www.example.com link.",
{"require_protocol": True},
u('Just a www.example.com link.')),
u'Just a www.example.com link.'),
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u('A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html" title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>')),
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>'), # noqa: E501
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u('A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi" title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!')),
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!'), # noqa: E501
("A file:///passwords.txt and http://web.com link", {},
u('A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link')),
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link'),
("A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u('A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link')),
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link'),
("www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
("www.external-link.com and www.internal-link.com/blogs extra",
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a> and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra')),
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra'), # noqa: E501
("www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
]
@ -141,13 +145,13 @@ class EscapeTestCase(unittest.TestCase):
def test_xhtml_escape(self):
tests = [
("<foo>", "&lt;foo&gt;"),
(u("<foo>"), u("&lt;foo&gt;")),
(u"<foo>", u"&lt;foo&gt;"),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
("&amp;", "&amp;amp;"),
(u("<\u00e9>"), u("&lt;\u00e9&gt;")),
(u"<\u00e9>", u"&lt;\u00e9&gt;"),
(b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
]
for unescaped, escaped in tests:
@ -159,7 +163,7 @@ class EscapeTestCase(unittest.TestCase):
('foo&#32;bar', 'foo bar'),
('foo&#x20;bar', 'foo bar'),
('foo&#X20;bar', 'foo bar'),
('foo&#xabc;bar', u('foo\u0abcbar')),
('foo&#xabc;bar', u'foo\u0abcbar'),
('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
('foo&#;bar', 'foo&#;bar'), # invalid encoding
('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
@ -170,20 +174,20 @@ class EscapeTestCase(unittest.TestCase):
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
(u('\u00e9').encode('utf8'), '%C3%A9'),
(u('\u00e9').encode('latin1'), '%E9'),
(u'\u00e9'.encode('utf8'), '%C3%A9'),
(u'\u00e9'.encode('latin1'), '%E9'),
# unicode strings become utf8
(u('\u00e9'), '%C3%A9'),
(u'\u00e9', '%C3%A9'),
]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
('%C3%A9', u('\u00e9'), 'utf8'),
('%C3%A9', u('\u00c3\u00a9'), 'latin1'),
('%C3%A9', utf8(u('\u00e9')), None),
('%C3%A9', u'\u00e9', 'utf8'),
('%C3%A9', u'\u00c3\u00a9', 'latin1'),
('%C3%A9', utf8(u'\u00e9'), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
@ -209,28 +213,29 @@ class EscapeTestCase(unittest.TestCase):
# On python2 the escape methods should generally return the same
# type as their argument
self.assertEqual(type(xhtml_escape("foo")), str)
self.assertEqual(type(xhtml_escape(u("foo"))), unicode_type)
self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)
def test_json_decode(self):
# json_decode accepts both bytes and unicode, but strings it returns
# are always unicode.
self.assertEqual(json_decode(b'"foo"'), u("foo"))
self.assertEqual(json_decode(u('"foo"')), u("foo"))
self.assertEqual(json_decode(b'"foo"'), u"foo")
self.assertEqual(json_decode(u'"foo"'), u"foo")
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u('"\u00e9"'))), u("\u00e9"))
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
def test_json_encode(self):
# json deals with strings, not bytes. On python 2 byte strings will
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
self.assertEqual(squeeze(u'sequences of whitespace chars'),
u'sequences of whitespace chars')
def test_recursive_unicode(self):
tests = {
@ -239,7 +244,7 @@ class EscapeTestCase(unittest.TestCase):
'tuple': (b"foo", b"bar"),
'bytes': b"foo"
}
self.assertEqual(recursive_unicode(tests['dict']), {u("foo"): u("bar")})
self.assertEqual(recursive_unicode(tests['list']), [u("foo"), u("bar")])
self.assertEqual(recursive_unicode(tests['tuple']), (u("foo"), u("bar")))
self.assertEqual(recursive_unicode(tests['bytes']), u("foo"))
self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
self.assertEqual(recursive_unicode(tests['bytes']), u"foo")

View file

@ -1,12 +1,15 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import gc
import contextlib
import datetime
import functools
import platform
import sys
import textwrap
import time
import weakref
import warnings
from tornado.concurrent import return_future, Future
from tornado.escape import url_escape
@ -15,7 +18,7 @@ from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test
from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test, ignore_deprecation # noqa: E501
from tornado.web import Application, RequestHandler, asynchronous, HTTPError
from tornado import gen
@ -25,12 +28,24 @@ try:
except ImportError:
futures = None
try:
import asyncio
except ImportError:
asyncio = None
class GenEngineTest(AsyncTestCase):
def setUp(self):
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
super(GenEngineTest, self).setUp()
self.named_contexts = []
def tearDown(self):
super(GenEngineTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
def named_context(self, name):
@contextlib.contextmanager
def context():
@ -53,9 +68,10 @@ class GenEngineTest(AsyncTestCase):
self.io_loop.add_callback(functools.partial(
self.delay_callback, iterations - 1, callback, arg))
@return_future
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
with ignore_deprecation():
@return_future
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
@gen.coroutine
def async_exception(self, e):
@ -276,6 +292,13 @@ class GenEngineTest(AsyncTestCase):
pass
self.orphaned_callback()
def test_none(self):
@gen.engine
def f():
yield None
self.stop()
self.run_gen(f)
def test_multi(self):
@gen.engine
def f():
@ -645,6 +668,237 @@ class GenEngineTest(AsyncTestCase):
self.assertIs(self.task_ref(), None)
# GenBasicTest duplicates the non-deprecated portions of GenEngineTest
# with gen.coroutine to ensure we don't lose coverage when gen.engine
# goes away.
class GenBasicTest(AsyncTestCase):
    """Non-deprecated duplicate of GenEngineTest, using gen.coroutine.

    Mirrors the still-supported portions of GenEngineTest so that test
    coverage is retained when gen.engine goes away.
    """

    @gen.coroutine
    def delay(self, iterations, arg):
        """Returns arg after a number of IOLoop iterations."""
        for i in range(iterations):
            yield gen.moment
        raise gen.Return(arg)

    # return_future is deprecated; the decorator fires at class-definition
    # time, so the warning must be suppressed here rather than per-test.
    with ignore_deprecation():
        @return_future
        def async_future(self, result, callback):
            self.io_loop.add_callback(callback, result)

    @gen.coroutine
    def async_exception(self, e):
        # Raise e asynchronously, i.e. after at least one IOLoop iteration.
        yield gen.moment
        raise e

    @gen.coroutine
    def add_one_async(self, x):
        yield gen.moment
        raise gen.Return(x + 1)

    def test_no_yield(self):
        # A coroutine body need not contain any yield at all.
        @gen.coroutine
        def f():
            pass
        self.io_loop.run_sync(f)

    def test_exception_phase1(self):
        # Exception raised before the first yield.
        @gen.coroutine
        def f():
            1 / 0
        self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)

    def test_exception_phase2(self):
        # Exception raised after resuming from a yield.
        @gen.coroutine
        def f():
            yield gen.moment
            1 / 0
        self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)

    def test_bogus_yield(self):
        # Yielding a non-yieldable value is an error.
        @gen.coroutine
        def f():
            yield 42
        self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)

    def test_bogus_yield_tuple(self):
        @gen.coroutine
        def f():
            yield (1, 2)
        self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)

    def test_reuse(self):
        # The same coroutine function may be run more than once.
        @gen.coroutine
        def f():
            yield gen.moment
        self.io_loop.run_sync(f)
        self.io_loop.run_sync(f)

    def test_none(self):
        # yield None is a no-op (equivalent to gen.moment).
        @gen.coroutine
        def f():
            yield None
        self.io_loop.run_sync(f)

    def test_multi(self):
        # Yielding a list runs the futures in parallel.
        @gen.coroutine
        def f():
            results = yield [self.add_one_async(1), self.add_one_async(2)]
            self.assertEqual(results, [2, 3])
        self.io_loop.run_sync(f)

    def test_multi_dict(self):
        # Yielding a dict of futures resolves to a dict of results.
        @gen.coroutine
        def f():
            results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2))
            self.assertEqual(results, dict(foo=2, bar=3))
        self.io_loop.run_sync(f)

    def test_multi_delayed(self):
        @gen.coroutine
        def f():
            # callbacks run at different times
            responses = yield gen.multi_future([
                self.delay(3, "v1"),
                self.delay(1, "v2"),
            ])
            self.assertEqual(responses, ["v1", "v2"])
        self.io_loop.run_sync(f)

    def test_multi_dict_delayed(self):
        @gen.coroutine
        def f():
            # callbacks run at different times
            responses = yield gen.multi_future(dict(
                foo=self.delay(3, "v1"),
                bar=self.delay(1, "v2"),
            ))
            self.assertEqual(responses, dict(foo="v1", bar="v2"))
        self.io_loop.run_sync(f)

    @skipOnTravis
    @gen_test
    def test_multi_performance(self):
        # Yielding a list used to have quadratic performance; make
        # sure a large list stays reasonable. On my laptop a list of
        # 2000 used to take 1.8s, now it takes 0.12.
        start = time.time()
        yield [gen.moment for i in range(2000)]
        end = time.time()
        self.assertLess(end - start, 1.0)

    @gen_test
    def test_multi_empty(self):
        # Empty lists or dicts should return the same type.
        x = yield []
        self.assertTrue(isinstance(x, list))
        y = yield {}
        self.assertTrue(isinstance(y, dict))

    @gen_test
    def test_future(self):
        result = yield self.async_future(1)
        self.assertEqual(result, 1)

    @gen_test
    def test_multi_future(self):
        results = yield [self.async_future(1), self.async_future(2)]
        self.assertEqual(results, [1, 2])

    @gen_test
    def test_multi_future_duplicate(self):
        # The same future may appear more than once in the yielded list.
        f = self.async_future(2)
        results = yield [self.async_future(1), f, self.async_future(3), f]
        self.assertEqual(results, [1, 2, 3, 2])

    @gen_test
    def test_multi_dict_future(self):
        results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
        self.assertEqual(results, dict(foo=1, bar=2))

    @gen_test
    def test_multi_exceptions(self):
        # Multiple failures: the first exception propagates, the rest are logged.
        with ExpectLog(app_log, "Multiple exceptions in yield list"):
            with self.assertRaises(RuntimeError) as cm:
                yield gen.Multi([self.async_exception(RuntimeError("error 1")),
                                 self.async_exception(RuntimeError("error 2"))])
        self.assertEqual(str(cm.exception), "error 1")

        # With only one exception, no error is logged.
        with self.assertRaises(RuntimeError):
            yield gen.Multi([self.async_exception(RuntimeError("error 1")),
                             self.async_future(2)])

        # Exception logging may be explicitly quieted.
        with self.assertRaises(RuntimeError):
            yield gen.Multi([self.async_exception(RuntimeError("error 1")),
                             self.async_exception(RuntimeError("error 2"))],
                            quiet_exceptions=RuntimeError)

    @gen_test
    def test_multi_future_exceptions(self):
        # Same as test_multi_exceptions, but yielding a bare list / multi_future.
        with ExpectLog(app_log, "Multiple exceptions in yield list"):
            with self.assertRaises(RuntimeError) as cm:
                yield [self.async_exception(RuntimeError("error 1")),
                       self.async_exception(RuntimeError("error 2"))]
        self.assertEqual(str(cm.exception), "error 1")

        # With only one exception, no error is logged.
        with self.assertRaises(RuntimeError):
            yield [self.async_exception(RuntimeError("error 1")),
                   self.async_future(2)]

        # Exception logging may be explicitly quieted.
        with self.assertRaises(RuntimeError):
            yield gen.multi_future(
                [self.async_exception(RuntimeError("error 1")),
                 self.async_exception(RuntimeError("error 2"))],
                quiet_exceptions=RuntimeError)

    def test_sync_raise_return(self):
        # gen.Return with no value, before any yield.
        @gen.coroutine
        def f():
            raise gen.Return()
        self.io_loop.run_sync(f)

    def test_async_raise_return(self):
        @gen.coroutine
        def f():
            yield gen.moment
            raise gen.Return()
        self.io_loop.run_sync(f)

    def test_sync_raise_return_value(self):
        @gen.coroutine
        def f():
            raise gen.Return(42)
        self.assertEqual(42, self.io_loop.run_sync(f))

    def test_sync_raise_return_value_tuple(self):
        # A tuple return value must round-trip intact (not be unpacked).
        @gen.coroutine
        def f():
            raise gen.Return((1, 2))
        self.assertEqual((1, 2), self.io_loop.run_sync(f))

    def test_async_raise_return_value(self):
        @gen.coroutine
        def f():
            yield gen.moment
            raise gen.Return(42)
        self.assertEqual(42, self.io_loop.run_sync(f))

    def test_async_raise_return_value_tuple(self):
        @gen.coroutine
        def f():
            yield gen.moment
            raise gen.Return((1, 2))
        self.assertEqual((1, 2), self.io_loop.run_sync(f))
class GenCoroutineTest(AsyncTestCase):
def setUp(self):
# Stray StopIteration exceptions can lead to tests exiting prematurely,
@ -657,6 +911,28 @@ class GenCoroutineTest(AsyncTestCase):
super(GenCoroutineTest, self).tearDown()
assert self.finished
def test_attributes(self):
self.finished = True
def f():
yield gen.moment
coro = gen.coroutine(f)
self.assertEqual(coro.__name__, f.__name__)
self.assertEqual(coro.__module__, f.__module__)
self.assertIs(coro.__wrapped__, f)
def test_is_coroutine_function(self):
self.finished = True
def f():
yield gen.moment
coro = gen.coroutine(f)
self.assertFalse(gen.is_coroutine_function(f))
self.assertTrue(gen.is_coroutine_function(coro))
self.assertFalse(gen.is_coroutine_function(coro()))
@gen_test
def test_sync_gen_return(self):
@gen.coroutine
@ -670,7 +946,7 @@ class GenCoroutineTest(AsyncTestCase):
def test_async_gen_return(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
raise gen.Return(42)
result = yield f()
self.assertEqual(result, 42)
@ -691,7 +967,7 @@ class GenCoroutineTest(AsyncTestCase):
namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
return 42
""")
result = yield namespace['f']()
@ -718,12 +994,32 @@ class GenCoroutineTest(AsyncTestCase):
@skipBefore35
@gen_test
def test_async_await(self):
@gen.coroutine
def f1():
yield gen.moment
raise gen.Return(42)
# This test verifies that an async function can await a
# yield-based gen.coroutine, and that a gen.coroutine
# (the test method itself) can yield an async function.
namespace = exec_test(globals(), locals(), """
async def f2():
result = await f1()
return result
""")
result = yield namespace['f2']()
self.assertEqual(result, 42)
self.finished = True
@skipBefore35
@gen_test
def test_asyncio_sleep_zero(self):
# asyncio.sleep(0) turns into a special case (equivalent to
# `yield None`)
namespace = exec_test(globals(), locals(), """
async def f():
await gen.Task(self.io_loop.add_callback)
import asyncio
await asyncio.sleep(0)
return 42
""")
result = yield namespace['f']()
@ -733,18 +1029,22 @@ class GenCoroutineTest(AsyncTestCase):
@skipBefore35
@gen_test
def test_async_await_mixed_multi_native_future(self):
@gen.coroutine
def f1():
yield gen.moment
namespace = exec_test(globals(), locals(), """
async def f1():
await gen.Task(self.io_loop.add_callback)
async def f2():
await f1()
return 42
""")
@gen.coroutine
def f2():
yield gen.Task(self.io_loop.add_callback)
def f3():
yield gen.moment
raise gen.Return(43)
results = yield [namespace['f1'](), f2()]
results = yield [namespace['f2'](), f3()]
self.assertEqual(results, [42, 43])
self.finished = True
@ -762,11 +1062,25 @@ class GenCoroutineTest(AsyncTestCase):
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(43)
f2(callback=(yield gen.Callback('cb')))
results = yield [namespace['f1'](), gen.Wait('cb')]
with ignore_deprecation():
f2(callback=(yield gen.Callback('cb')))
results = yield [namespace['f1'](), gen.Wait('cb')]
self.assertEqual(results, [42, 43])
self.finished = True
@skipBefore35
@gen_test
def test_async_with_timeout(self):
namespace = exec_test(globals(), locals(), """
async def f1():
return 42
""")
result = yield gen.with_timeout(datetime.timedelta(hours=1),
namespace['f1']())
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_sync_return_no_value(self):
@gen.coroutine
@ -781,7 +1095,7 @@ class GenCoroutineTest(AsyncTestCase):
# Without a return value we don't need python 3.3.
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
return
result = yield f()
self.assertEqual(result, None)
@ -804,7 +1118,7 @@ class GenCoroutineTest(AsyncTestCase):
def test_async_raise(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
1 / 0
future = f()
with self.assertRaises(ZeroDivisionError):
@ -813,10 +1127,11 @@ class GenCoroutineTest(AsyncTestCase):
@gen_test
def test_pass_callback(self):
@gen.coroutine
def f():
raise gen.Return(42)
result = yield gen.Task(f)
with ignore_deprecation():
@gen.coroutine
def f():
raise gen.Return(42)
result = yield gen.Task(f)
self.assertEqual(result, 42)
self.finished = True
@ -861,46 +1176,48 @@ class GenCoroutineTest(AsyncTestCase):
@gen_test
def test_replace_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and replaced.
# Note that this test and the following are for behavior that is
# not really supported any more: coroutines no longer create a
# stack context automatically; but one is created after the first
# YieldPoint (i.e. not a Future).
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise KeyError()
with ignore_deprecation():
# Test exception handling: exceptions thrown into the stack context
# can be caught and replaced.
# Note that this test and the following are for behavior that is
# not really supported any more: coroutines no longer create a
# stack context automatically; but one is created after the first
# YieldPoint (i.e. not a Future).
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise KeyError()
future = f2()
with self.assertRaises(KeyError):
yield future
self.finished = True
future = f2()
with self.assertRaises(KeyError):
yield future
self.finished = True
@gen_test
def test_swallow_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and ignored.
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise gen.Return(42)
with ignore_deprecation():
# Test exception handling: exceptions thrown into the stack context
# can be caught and ignored.
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise gen.Return(42)
result = yield f2()
self.assertEqual(result, 42)
self.finished = True
result = yield f2()
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_moment(self):
@ -957,77 +1274,132 @@ class GenCoroutineTest(AsyncTestCase):
self.finished = True
@skipNotCPython
@unittest.skipIf((3,) < sys.version_info < (3, 6),
"asyncio.Future has reference cycles")
def test_coroutine_refcounting(self):
# On CPython, tasks and their arguments should be released immediately
# without waiting for garbage collection.
@gen.coroutine
def inner():
class Foo(object):
pass
local_var = Foo()
self.local_ref = weakref.ref(local_var)
yield gen.coroutine(lambda: None)()
raise ValueError('Some error')
@gen.coroutine
def inner2():
try:
yield inner()
except ValueError:
pass
self.io_loop.run_sync(inner2, timeout=3)
self.assertIs(self.local_ref(), None)
self.finished = True
@unittest.skipIf(sys.version_info < (3,),
"test only relevant with asyncio Futures")
def test_asyncio_future_debug_info(self):
self.finished = True
# Enable debug mode
asyncio_loop = asyncio.get_event_loop()
self.addCleanup(asyncio_loop.set_debug, asyncio_loop.get_debug())
asyncio_loop.set_debug(True)
def f():
yield gen.moment
coro = gen.coroutine(f)()
self.assertIsInstance(coro, asyncio.Future)
# We expect the coroutine repr() to show the place where
# it was instantiated
expected = ("created at %s:%d"
% (__file__, f.__code__.co_firstlineno + 3))
actual = repr(coro)
self.assertIn(expected, actual)
@unittest.skipIf(asyncio is None, "asyncio module not present")
@gen_test
def test_asyncio_gather(self):
# This demonstrates that tornado coroutines can be understood
# by asyncio (This failed prior to Tornado 5.0).
@gen.coroutine
def f():
yield gen.moment
raise gen.Return(1)
ret = yield asyncio.gather(f(), f())
self.assertEqual(ret, [1, 1])
self.finished = True
class GenSequenceHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.finish("3")
with ignore_deprecation():
@asynchronous
@gen.engine
def get(self):
# The outer ignore_deprecation applies at definition time.
# We need another for serving time.
with ignore_deprecation():
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.finish("3")
class GenCoroutineSequenceHandler(RequestHandler):
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
yield gen.moment
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
yield gen.moment
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
yield gen.moment
self.finish("3")
class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
yield gen.moment
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
yield gen.moment
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
yield gen.moment
# just write, don't finish
self.write("3")
class GenTaskHandler(RequestHandler):
@asynchronous
@gen.engine
@gen.coroutine
def get(self):
io_loop = self.request.connection.stream.io_loop
client = AsyncHTTPClient(io_loop=io_loop)
response = yield gen.Task(client.fetch, self.get_argument('url'))
client = AsyncHTTPClient()
with ignore_deprecation():
response = yield gen.Task(client.fetch, self.get_argument('url'))
response.rethrow()
self.finish(b"got response: " + response.body)
class GenExceptionHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
# This test depends on the order of the two decorators.
io_loop = self.request.connection.stream.io_loop
yield gen.Task(io_loop.add_callback)
raise Exception("oops")
with ignore_deprecation():
@asynchronous
@gen.engine
def get(self):
# This test depends on the order of the two decorators.
io_loop = self.request.connection.stream.io_loop
yield gen.Task(io_loop.add_callback)
raise Exception("oops")
class GenCoroutineExceptionHandler(RequestHandler):
@ -1040,19 +1412,18 @@ class GenCoroutineExceptionHandler(RequestHandler):
class GenYieldExceptionHandler(RequestHandler):
@asynchronous
@gen.engine
@gen.coroutine
def get(self):
io_loop = self.request.connection.stream.io_loop
# Test the interaction of the two stack_contexts.
def fail_task(callback):
io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(fail_task)
raise Exception("did not get expected exception")
except ZeroDivisionError:
self.finish('ok')
with ignore_deprecation():
def fail_task(callback):
io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(fail_task)
raise Exception("did not get expected exception")
except ZeroDivisionError:
self.finish('ok')
# "Undecorated" here refers to the absence of @asynchronous.
@ -1060,22 +1431,22 @@ class UndecoratedCoroutinesHandler(RequestHandler):
@gen.coroutine
def prepare(self):
self.chunks = []
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
self.chunks.append('1')
@gen.coroutine
def get(self):
self.chunks.append('2')
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
self.chunks.append('3')
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
self.write(''.join(self.chunks))
class AsyncPrepareErrorHandler(RequestHandler):
@gen.coroutine
def prepare(self):
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
raise HTTPError(403)
def get(self):
@ -1086,7 +1457,8 @@ class NativeCoroutineHandler(RequestHandler):
if sys.version_info > (3, 5):
exec(textwrap.dedent("""
async def get(self):
await gen.Task(IOLoop.current().add_callback)
import asyncio
await asyncio.sleep(0)
self.write("ok")
"""))
@ -1167,7 +1539,7 @@ class WithTimeoutTest(AsyncTestCase):
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future, io_loop=self.io_loop)
future)
self.assertEqual(result, 'asdf')
@gen_test
@ -1178,19 +1550,20 @@ class WithTimeoutTest(AsyncTestCase):
lambda: future.set_exception(ZeroDivisionError()))
with self.assertRaises(ZeroDivisionError):
yield gen.with_timeout(datetime.timedelta(seconds=3600),
future, io_loop=self.io_loop)
future)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future, io_loop=self.io_loop)
future)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_timeout_concurrent_future(self):
# A concurrent future that does not resolve before the timeout.
with futures.ThreadPoolExecutor(1) as executor:
with self.assertRaises(gen.TimeoutError):
yield gen.with_timeout(self.io_loop.time(),
@ -1199,9 +1572,20 @@ class WithTimeoutTest(AsyncTestCase):
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_completed_concurrent_future(self):
# A concurrent future that is resolved before we even submit it
# to with_timeout.
with futures.ThreadPoolExecutor(1) as executor:
f = executor.submit(lambda: None)
f.result() # wait for completion
yield gen.with_timeout(datetime.timedelta(seconds=3600), f)
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_normal_concurrent_future(self):
# A concurrent future that resolves while waiting for the timeout.
with futures.ThreadPoolExecutor(1) as executor:
yield gen.with_timeout(datetime.timedelta(seconds=3600),
executor.submit(lambda: None))
executor.submit(lambda: time.sleep(0.01)))
class WaitIteratorTest(AsyncTestCase):
@ -1355,5 +1739,124 @@ class WaitIteratorTest(AsyncTestCase):
gen.WaitIterator(gen.sleep(0)).next())
class RunnerGCTest(AsyncTestCase):
    """Tests for garbage-collection behavior of coroutine Runner objects."""

    def is_pypy3(self):
        # Some finalizer-timing assertions below do not hold on PyPy3,
        # so callers use this to relax them.
        return (platform.python_implementation() == 'PyPy' and
                sys.version_info > (3,))

    @gen_test
    def test_gc(self):
        # Github issue 1769: Runner objects can get GCed unexpectedly
        # while their future is alive.
        weakref_scope = [None]

        def callback():
            # Force a full collection, then resolve the future via its
            # weakref — this fails if the Runner was collected.
            gc.collect(2)
            weakref_scope[0]().set_result(123)

        @gen.coroutine
        def tester():
            fut = Future()
            weakref_scope[0] = weakref.ref(fut)
            self.io_loop.add_callback(callback)
            yield fut

        yield gen.with_timeout(
            datetime.timedelta(seconds=0.2),
            tester()
        )

    def test_gc_infinite_coro(self):
        # Github issue 2229: suspended coroutines should be GCed when
        # their loop is closed, even if they're involved in a reference
        # cycle.
        if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
            raise unittest.SkipTest("Test may fail on TwistedIOLoop")

        loop = self.get_new_ioloop()
        result = []
        wfut = []

        @gen.coroutine
        def infinite_coro():
            try:
                while True:
                    yield gen.sleep(1e-3)
                    result.append(True)
            finally:
                # coroutine finalizer
                result.append(None)

        @gen.coroutine
        def do_something():
            fut = infinite_coro()
            # Create a deliberate reference cycle through the future.
            fut._refcycle = fut
            wfut.append(weakref.ref(fut))
            yield gen.sleep(0.2)

        loop.run_sync(do_something)
        loop.close()
        gc.collect()
        # Future was collected
        self.assertIs(wfut[0](), None)
        # At least one wakeup
        self.assertGreaterEqual(len(result), 2)
        if not self.is_pypy3():
            # coroutine finalizer was called (not on PyPy3 apparently)
            self.assertIs(result[-1], None)

    @skipBefore35
    def test_gc_infinite_async_await(self):
        # Same as test_gc_infinite_coro, but with a `async def` function
        import asyncio

        namespace = exec_test(globals(), locals(), """
            async def infinite_coro(result):
                try:
                    while True:
                        await gen.sleep(1e-3)
                        result.append(True)
                finally:
                    # coroutine finalizer
                    result.append(None)
        """)
        infinite_coro = namespace['infinite_coro']
        loop = self.get_new_ioloop()
        result = []
        wfut = []

        @gen.coroutine
        def do_something():
            fut = asyncio.get_event_loop().create_task(infinite_coro(result))
            # Create a deliberate reference cycle through the task.
            fut._refcycle = fut
            wfut.append(weakref.ref(fut))
            yield gen.sleep(0.2)

        loop.run_sync(do_something)
        with ExpectLog('asyncio', "Task was destroyed but it is pending"):
            loop.close()
        gc.collect()
        # Future was collected
        self.assertIs(wfut[0](), None)
        # At least one wakeup and one finally
        self.assertGreaterEqual(len(result), 2)
        if not self.is_pypy3():
            # coroutine finalizer was called (not on PyPy3 apparently)
            self.assertIs(result[-1], None)

    def test_multi_moment(self):
        # Test gen.multi with moment
        # now that it's not a real Future
        @gen.coroutine
        def wait_a_moment():
            result = yield gen.multi([gen.moment, gen.moment])
            raise gen.Return(result)

        loop = self.get_new_ioloop()
        result = loop.run_sync(wait_a_moment)
        self.assertEqual(result, [None, None])
# Allow this test module to be run directly (e.g. `python -m tornado.test.gen_test`).
if __name__ == '__main__':
    unittest.main()

View file

@ -0,0 +1,61 @@
from __future__ import absolute_import, division, print_function
import socket
from tornado.http1connection import HTTP1Connection
from tornado.httputil import HTTPMessageDelegate
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.netutil import add_accept_handler
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class HTTP1ConnectionTest(AsyncTestCase):
    """Exercises HTTP1Connection against a raw loopback socket pair.

    setUp establishes a connected (server_stream, client_stream) pair of
    IOStreams so server responses can be faked byte-for-byte.
    """

    def setUp(self):
        super(HTTP1ConnectionTest, self).setUp()
        # asyncSetUp is decorated with gen_test, so calling it here runs
        # the coroutine to completion on this test's IOLoop.
        self.asyncSetUp()

    @gen_test
    def asyncSetUp(self):
        listener, port = bind_unused_port()
        event = Event()

        def accept_callback(conn, addr):
            self.server_stream = IOStream(conn)
            self.addCleanup(self.server_stream.close)
            event.set()

        add_accept_handler(listener, accept_callback)
        self.client_stream = IOStream(socket.socket())
        self.addCleanup(self.client_stream.close)
        # Wait for both sides of the connection to be established.
        yield [self.client_stream.connect(('127.0.0.1', port)),
               event.wait()]
        self.io_loop.remove_handler(listener)
        listener.close()

    @gen_test
    def test_http10_no_content_length(self):
        # Regression test for a bug in which can_keep_alive would crash
        # for an HTTP/1.0 (not 1.1) response with no content-length.
        conn = HTTP1Connection(self.client_stream, True)
        self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
        self.server_stream.close()

        event = Event()
        test = self
        body = []

        class Delegate(HTTPMessageDelegate):
            def headers_received(self, start_line, headers):
                test.code = start_line.code

            def data_received(self, data):
                body.append(data)

            def finish(self):
                event.set()

        yield conn.read_response(Delegate())
        yield event.wait()
        self.assertEqual(self.code, 200)
        self.assertEqual(b''.join(body), b'hello')

View file

@ -1,18 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import base64
import binascii
from contextlib import closing
import copy
import functools
import sys
import threading
import datetime
from io import BytesIO
import time
import unicodedata
from tornado.escape import utf8
from tornado.escape import utf8, native_str
from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
@ -22,8 +22,7 @@ from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u
from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
@ -114,6 +113,15 @@ class AllMethodsHandler(RequestHandler):
get = post = put = delete = options = patch = other = method
class SetHeaderHandler(RequestHandler):
    """Sets one response header per (k, v) query-argument pair."""

    def get(self):
        # Header names come from get_arguments (native strings), while
        # header values come from request.arguments (raw bytes).
        names = self.get_arguments('k')
        values = self.request.arguments['v']
        for name, value in zip(names, values):
            self.set_header(name, value)
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
@ -134,6 +142,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
url('/patch', PatchHandler),
url('/set_header', SetHeaderHandler),
], gzip=True)
def test_patch_receives_payload(self):
@ -183,10 +192,15 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
stream.write(b"""\
yield stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
@ -196,17 +210,10 @@ Transfer-Encoding: chunked
2
0
""".replace(b"\n", b"\r\n"), callback=stream.close)
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn, io_loop=self.io_loop)
stream.read_until(b"\r\n\r\n",
functools.partial(write_response, stream))
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
resp = self.wait()
""".replace(b"\n", b"\r\n"))
stream.close()
netutil.add_accept_handler(sock, accept_callback)
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
@ -224,14 +231,16 @@ Transfer-Encoding: chunked
if chunk == b'qwer':
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', streaming_callback=streaming_cb)
with ignore_deprecation():
with ExceptionStackContext(error_handler):
self.fetch('/chunk', streaming_callback=streaming_cb)
self.assertEqual(chunks, [b'asdf', b'qwer'])
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_basic_auth(self):
# This test data appears in section 2 of RFC 7617.
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
@ -242,16 +251,30 @@ Transfer-Encoding: chunked
auth_mode="basic").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_basic_auth_unicode(self):
# This test data appears in section 2.1 of RFC 7617.
self.assertEqual(self.fetch("/auth", auth_username="test",
auth_password="123£").body,
b"Basic dGVzdDoxMjPCow==")
# The standard mandates NFC. Give it a decomposed username
# and ensure it is normalized to composed form.
username = unicodedata.normalize("NFD", u"josé")
self.assertEqual(self.fetch("/auth",
auth_username=username,
auth_password="səcrət").body,
b"Basic am9zw6k6c8mZY3LJmXQ=")
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
response = self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf")
response.rethrow()
self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf",
raise_error=True)
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
@ -265,13 +288,12 @@ Transfer-Encoding: chunked
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
response = self.fetch(url)
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
response.body)
def test_body_encoding(self):
unicode_body = u("\xe9")
unicode_body = u"\xe9"
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
@ -291,7 +313,7 @@ Transfer-Encoding: chunked
# break anything
response = self.fetch("/echopost", method="POST", body=byte_body,
headers={"Content-Type": "application/blah"},
user_agent=u("foo"))
user_agent=u"foo")
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
@ -341,19 +363,20 @@ Transfer-Encoding: chunked
if header_line.lower().startswith('content-type:'):
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', header_callback=header_callback)
with ignore_deprecation():
with ExceptionStackContext(error_handler):
self.fetch('/chunk', header_callback=header_callback)
self.assertEqual(len(exc_info), 1)
self.assertIs(exc_info[0][0], ZeroDivisionError)
@gen_test
def test_configure_defaults(self):
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
client = self.http_client.__class__(force_instance=True,
defaults=defaults)
try:
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
response = yield client.fetch(self.get_url('/user_agent'))
self.assertEqual(response.body, b'TestDefaultUserAgent')
finally:
client.close()
@ -363,7 +386,7 @@ Transfer-Encoding: chunked
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u("MyUserAgent"), b"MyUserAgent"]:
for value in [u"MyUserAgent", b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers['User-Agent'] = value
@ -378,23 +401,22 @@ Transfer-Encoding: chunked
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
@gen.coroutine
def accept_callback(conn, address):
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
stream.write(b"""\
yield stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
""".replace(b"\n", b"\r\n"), callback=stream.close)
""".replace(b"\n", b"\r\n"))
stream.close()
def accept_callback(conn, address):
stream = IOStream(conn, io_loop=self.io_loop)
stream.read_until(b"\r\n\r\n",
functools.partial(write_response, stream))
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
resp = self.wait()
netutil.add_accept_handler(sock, accept_callback)
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
self.io_loop.remove_handler(sock.fileno())
@ -424,8 +446,9 @@ X-XSS-Protection: 1;
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
self.http_client.fetch(self.get_url('/hello'),
lambda response: 1 / 0)
with ignore_deprecation():
self.http_client.fetch(self.get_url('/hello'),
lambda response: 1 / 0)
self.wait()
self.assertEqual(exc_info[0][0], ZeroDivisionError)
@ -476,8 +499,7 @@ X-XSS-Protection: 1;
# These methods require a body.
for method in ('POST', 'PUT', 'PATCH'):
with self.assertRaises(ValueError) as context:
resp = self.fetch('/all_methods', method=method)
resp.rethrow()
self.fetch('/all_methods', method=method, raise_error=True)
self.assertIn('must not be None', str(context.exception))
resp = self.fetch('/all_methods', method=method,
@ -487,16 +509,14 @@ X-XSS-Protection: 1;
# These methods don't allow a body.
for method in ('GET', 'DELETE', 'OPTIONS'):
with self.assertRaises(ValueError) as context:
resp = self.fetch('/all_methods', method=method, body=b'asdf')
resp.rethrow()
self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
self.assertIn('must be None', str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
if method != 'GET':
resp = self.fetch('/all_methods', method=method, body=b'asdf',
allow_nonstandard_methods=True)
resp.rethrow()
self.fetch('/all_methods', method=method, body=b'asdf',
allow_nonstandard_methods=True, raise_error=True)
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
@ -521,6 +541,28 @@ X-XSS-Protection: 1;
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
def test_non_ascii_header(self):
# Non-ascii headers are sent as latin1.
response = self.fetch("/set_header?k=foo&v=%E9")
response.rethrow()
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
def test_response_times(self):
# A few simple sanity checks of the response time fields to
# make sure they're using the right basis (between the
# wall-time and monotonic clocks).
start_time = time.time()
response = self.fetch("/hello")
response.rethrow()
self.assertGreaterEqual(response.request_time, 0)
self.assertLess(response.request_time, 1.0)
# A very crude check to make sure that start_time is based on
# wall time and not the monotonic clock.
self.assertLess(abs(response.start_time - start_time), 1.0)
for k, v in response.time_info.items():
self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
@ -567,22 +609,20 @@ class HTTPResponseTestCase(unittest.TestCase):
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
if IOLoop.configured_class().__name__ in ('TwistedIOLoop',
'AsyncIOMainLoop'):
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
# TwistedIOLoop only supports the global reactor, so we can't have
# separate IOLoops for client and server threads.
# AsyncIOMainLoop doesn't work with the default policy
# (although it could with some tweaks to this test and a
# policy that created loops for non-main threads).
raise unittest.SkipTest(
'Sync HTTPClient not compatible with TwistedIOLoop or '
'AsyncIOMainLoop')
'Sync HTTPClient not compatible with TwistedIOLoop')
self.server_ioloop = IOLoop()
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
self.server = HTTPServer(app, io_loop=self.server_ioloop)
self.server.add_socket(sock)
@gen.coroutine
def init_server():
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
self.server_ioloop.run_sync(init_server)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
self.server_thread.start()
@ -592,12 +632,20 @@ class SyncHTTPClientTest(unittest.TestCase):
def tearDown(self):
def stop_server():
self.server.stop()
# Delay the shutdown of the IOLoop by one iteration because
# Delay the shutdown of the IOLoop by several iterations because
# the server may still have some cleanup work left when
# the client finishes with the response (this is noticable
# the client finishes with the response (this is noticeable
# with http/2, which leaves a Future with an unexamined
# StreamClosedError on the loop).
self.server_ioloop.add_callback(self.server_ioloop.stop)
@gen.coroutine
def slow_stop():
# The number of iterations is difficult to predict. Typically,
# one is sufficient, although sometimes it needs more.
for i in range(5):
yield
self.server_ioloop.stop()
self.server_ioloop.add_callback(slow_stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
@ -656,6 +704,15 @@ class HTTPErrorTestCase(unittest.TestCase):
self.assertIsNot(e, e2)
self.assertEqual(e.code, e2.code)
def test_str(self):
def test_plain_error(self):
e = HTTPError(403)
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")
def test_error_with_response(self):
resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
with self.assertRaises(HTTPError) as cm:
resp.rethrow()
e = cm.exception
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")

View file

@ -1,21 +1,21 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
from tornado import netutil
from tornado import gen, netutil
from tornado.concurrent import Future
from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
from tornado import gen
from tornado.http1connection import HTTP1Connection
from tornado.httpclient import HTTPError
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from tornado.web import Application, RequestHandler, stream_request_body
from contextlib import closing
import datetime
import gzip
@ -30,18 +30,20 @@ from io import BytesIO
def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
start_line, headers and body."""
chunks = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
self.start_line = start_line
def data_received(self, chunk):
chunks.append(chunk)
def finish(self):
callback((self.headers, b''.join(chunks)))
conn.detach()
callback((self.start_line, self.headers, b''.join(chunks)))
conn = HTTP1Connection(stream, True)
conn.read_response(Delegate())
@ -87,7 +89,7 @@ class BaseSSLTest(AsyncHTTPSTestCase):
class SSLTestMixin(object):
def get_ssl_options(self):
return dict(ssl_version=self.get_ssl_version(),
return dict(ssl_version=self.get_ssl_version(), # type: ignore
**AsyncHTTPSTestCase.get_ssl_options())
def get_ssl_version(self):
@ -109,22 +111,19 @@ class SSLTestMixin(object):
# misbehaving.
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
with ExpectLog(gen_log, 'Uncaught exception', required=False):
self.http_client.fetch(
self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
self.assertEqual(response.code, 599)
with self.assertRaises((IOError, HTTPError)):
self.fetch(
self.get_url("/").replace('https:', 'http:'),
request_timeout=3600,
connect_timeout=3600,
raise_error=True)
def test_error_logging(self):
# No stack traces are logged for SSL errors.
with ExpectLog(gen_log, 'SSL Error') as expect_log:
self.http_client.fetch(
self.get_url("/").replace("https:", "http:"),
self.stop)
response = self.wait()
self.assertEqual(response.code, 599)
with self.assertRaises((IOError, HTTPError)):
self.fetch(self.get_url("/").replace("https:", "http:"),
raise_error=True)
self.assertFalse(expect_log.logged_stack)
# Python's SSL implementation differs significantly between versions.
@ -150,7 +149,6 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin):
return ssl.PROTOCOL_TLSv1
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class SSLContextTest(BaseSSLTest, SSLTestMixin):
def get_ssl_options(self):
context = ssl_options_to_context(
@ -211,14 +209,13 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.io_loop.run_sync(lambda: stream.connect(('127.0.0.1', self.get_http_port())))
stream.write(
newline.join(headers +
[utf8("Content-Length: %d" % len(body))]) +
newline + newline + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
start_line, headers, body = self.wait()
return body
def test_multipart_form(self):
@ -232,19 +229,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
b"\r\n".join([
b"Content-Disposition: form-data; name=argument",
b"",
u("\u00e1").encode("utf-8"),
u"\u00e1".encode("utf-8"),
b"--1234567890",
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
b"",
u("\u00fa").encode("utf-8"),
u"\u00fa".encode("utf-8"),
b"--1234567890--",
b"",
]))
data = json_decode(response)
self.assertEqual(u("\u00e9"), data["header"])
self.assertEqual(u("\u00e1"), data["argument"])
self.assertEqual(u("\u00f3"), data["filename"])
self.assertEqual(u("\u00fa"), data["filebody"])
self.assertEqual(u"\u00e9", data["header"])
self.assertEqual(u"\u00e1", data["argument"])
self.assertEqual(u"\u00f3", data["filename"])
self.assertEqual(u"\u00fa", data["filebody"])
def test_newlines(self):
# We support both CRLF and bare LF as line separators.
@ -253,31 +250,27 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
newline=newline)
self.assertEqual(response, b'Hello world')
@gen_test
def test_100_continue(self):
# Run through a 100-continue interaction by hand:
# When given Expect: 100-continue, we get a 100 response after the
# headers, and then the real response after the body.
stream = IOStream(socket.socket(), io_loop=self.io_loop)
stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
b"Expect: 100-continue",
b"Connection: close",
b"\r\n"]), callback=self.stop)
self.wait()
stream.read_until(b"\r\n\r\n", self.stop)
data = self.wait()
stream = IOStream(socket.socket())
yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"\r\n".join([
b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
b"Expect: 100-continue",
b"Connection: close",
b"\r\n"]))
data = yield stream.read_until(b"\r\n\r\n")
self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
stream.write(b"a" * 1024)
stream.read_until(b"\r\n", self.stop)
first_line = self.wait()
first_line = yield stream.read_until(b"\r\n")
self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
stream.read_until(b"\r\n\r\n", self.stop)
header_data = self.wait()
header_data = yield stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
stream.read_bytes(int(headers["Content-Length"]), self.stop)
body = self.wait()
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Got 1024 bytes in POST")
stream.close()
@ -340,17 +333,17 @@ class HTTPServerTest(AsyncHTTPTestCase):
def test_query_string_encoding(self):
response = self.fetch("/echo?foo=%C3%A9")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u("\u00e9")]})
self.assertEqual(data, {u"foo": [u"\u00e9"]})
def test_empty_query_string(self):
response = self.fetch("/echo?foo=&foo=")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u(""), u("")]})
self.assertEqual(data, {u"foo": [u"", u""]})
def test_empty_post_parameters(self):
response = self.fetch("/echo", method="POST", body="foo=&bar=")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u("")], u("bar"): [u("")]})
self.assertEqual(data, {u"foo": [u""], u"bar": [u""]})
def test_types(self):
headers = {"Cookie": "foo=bar"}
@ -395,8 +388,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port())))
def tearDown(self):
self.stream.close()
@ -407,19 +399,28 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
self.wait()
def test_malformed_first_line(self):
def test_malformed_first_line_response(self):
with ExpectLog(gen_log, '.*Malformed HTTP request line'):
self.stream.write(b'asdf\r\n\r\n')
read_stream_body(self.stream, self.stop)
start_line, headers, response = self.wait()
self.assertEqual('HTTP/1.1', start_line.version)
self.assertEqual(400, start_line.code)
self.assertEqual('Bad Request', start_line.reason)
def test_malformed_first_line_log(self):
with ExpectLog(gen_log, '.*Malformed HTTP request line'):
self.stream.write(b'asdf\r\n\r\n')
# TODO: need an async version of ExpectLog so we don't need
# hard-coded timeouts here.
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
self.stop)
self.wait()
def test_malformed_headers(self):
with ExpectLog(gen_log, '.*Malformed HTTP headers'):
with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'):
self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
self.stop)
self.wait()
@ -439,18 +440,50 @@ bar
""".replace(b"\n", b"\r\n"))
read_stream_body(self.stream, self.stop)
headers, response = self.wait()
self.assertEqual(json_decode(response), {u('foo'): [u('bar')]})
start_line, headers, response = self.wait()
self.assertEqual(json_decode(response), {u'foo': [u'bar']})
def test_chunked_request_uppercase(self):
# As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
# case-insensitive.
self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded
4
foo=
3
bar
0
""".replace(b"\n", b"\r\n"))
read_stream_body(self.stream, self.stop)
start_line, headers, response = self.wait()
self.assertEqual(json_decode(response), {u'foo': [u'bar']})
@gen_test
def test_invalid_content_length(self):
with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo
bar
""".replace(b"\n", b"\r\n"))
yield self.stream.read_until_close()
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
self.set_header('request-version', self.request.version)
self.write(dict(remote_ip=self.request.remote_ip,
remote_protocol=self.request.protocol))
def get_httpserver_options(self):
return dict(xheaders=True)
return dict(xheaders=True, trusted_downstream=['5.5.5.5'])
def test_ip_headers(self):
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
@ -490,6 +523,16 @@ class XHeaderTest(HandlerBaseTestCase):
self.fetch_json("/", headers=invalid_host)["remote_ip"],
"127.0.0.1")
def test_trusted_downstream(self):
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}
resp = self.fetch("/", headers=valid_ipv4_list)
if resp.headers['request-version'].startswith('HTTP/2'):
# This is a hack - there's nothing that fundamentally requires http/1
# here but tornado_http2 doesn't support it yet.
self.skipTest('requires HTTP/1.x')
result = json_decode(resp.body)
self.assertEqual(result['remote_ip'], "4.4.4.4")
def test_scheme_headers(self):
self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")
@ -503,6 +546,16 @@ class XHeaderTest(HandlerBaseTestCase):
self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
"https")
https_multi_forwarded = {"X-Forwarded-Proto": "https , http"}
self.assertEqual(
self.fetch_json("/", headers=https_multi_forwarded)["remote_protocol"],
"http")
http_multi_forwarded = {"X-Forwarded-Proto": "http,https"}
self.assertEqual(
self.fetch_json("/", headers=http_multi_forwarded)["remote_protocol"],
"https")
bad_forwarded = {"X-Forwarded-Proto": "unknown"}
self.assertEqual(
self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
@ -560,37 +613,36 @@ class UnixSocketTest(AsyncTestCase):
self.sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(self.sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
self.server = HTTPServer(app, io_loop=self.io_loop)
self.server = HTTPServer(app)
self.server.add_socket(sock)
self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
self.stream.connect(self.sockfile, self.stop)
self.wait()
self.stream = IOStream(socket.socket(socket.AF_UNIX))
self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))
def tearDown(self):
self.stream.close()
self.io_loop.run_sync(self.server.close_all_connections)
self.server.stop()
shutil.rmtree(self.tmpdir)
super(UnixSocketTest, self).tearDown()
@gen_test
def test_unix_socket(self):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
self.stream.read_until(b"\r\n", self.stop)
response = self.wait()
response = yield self.stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
self.stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
body = self.wait()
header_data = yield self.stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode('latin1'))
body = yield self.stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")
@gen_test
def test_unix_socket_bad_request(self):
# Unix sockets don't have remote addresses so they just return an
# empty string.
with ExpectLog(gen_log, "Malformed HTTP message from"):
self.stream.write(b"garbage\r\n\r\n")
self.stream.read_until_close(self.stop)
response = self.wait()
self.assertEqual(response, b"")
response = yield self.stream.read_until_close()
self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
class KeepAliveTest(AsyncHTTPTestCase):
@ -614,9 +666,11 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.write(''.join(chr(i % 256) * 1024 for i in range(512)))
class FinishOnCloseHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
self.flush()
never_finish = Event()
yield never_finish.wait()
def on_connection_close(self):
# This is not very realistic, but finishing the request
@ -643,119 +697,129 @@ class KeepAliveTest(AsyncHTTPTestCase):
super(KeepAliveTest, self).tearDown()
# The next few methods are a crude manual http client
@gen.coroutine
def connect(self):
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.stream = IOStream(socket.socket())
yield self.stream.connect(('127.0.0.1', self.get_http_port()))
@gen.coroutine
def read_headers(self):
self.stream.read_until(b'\r\n', self.stop)
first_line = self.wait()
first_line = yield self.stream.read_until(b'\r\n')
self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
self.stream.read_until(b'\r\n\r\n', self.stop)
header_bytes = self.wait()
header_bytes = yield self.stream.read_until(b'\r\n\r\n')
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
return headers
raise gen.Return(headers)
@gen.coroutine
def read_response(self):
self.headers = self.read_headers()
self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
body = self.wait()
self.headers = yield self.read_headers()
body = yield self.stream.read_bytes(int(self.headers['Content-Length']))
self.assertEqual(b'Hello world', body)
def close(self):
self.stream.close()
del self.stream
@gen_test
def test_two_requests(self):
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
self.read_response()
yield self.read_response()
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
self.read_response()
yield self.read_response()
self.close()
@gen_test
def test_request_close(self):
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
self.read_response()
self.stream.read_until_close(callback=self.stop)
data = self.wait()
yield self.read_response()
data = yield self.stream.read_until_close()
self.assertTrue(not data)
self.assertEqual(self.headers['Connection'], 'close')
self.close()
# keepalive is supported for http 1.0 too, but it's opt-in
@gen_test
def test_http10(self):
self.http_version = b'HTTP/1.0'
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
self.read_response()
self.stream.read_until_close(callback=self.stop)
data = self.wait()
yield self.read_response()
data = yield self.stream.read_until_close()
self.assertTrue(not data)
self.assertTrue('Connection' not in self.headers)
self.close()
@gen_test
def test_http10_keepalive(self):
self.http_version = b'HTTP/1.0'
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
@gen_test
def test_http10_keepalive_extra_crlf(self):
self.http_version = b'HTTP/1.0'
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
self.read_response()
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
@gen_test
def test_pipelined_requests(self):
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
self.read_response()
self.read_response()
yield self.read_response()
yield self.read_response()
self.close()
@gen_test
def test_pipelined_cancel(self):
self.connect()
yield self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
# only read once
self.read_response()
yield self.read_response()
self.close()
@gen_test
def test_cancel_during_download(self):
self.connect()
yield self.connect()
self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
self.read_headers()
self.stream.read_bytes(1024, self.stop)
self.wait()
yield self.read_headers()
yield self.stream.read_bytes(1024)
self.close()
@gen_test
def test_finish_while_closed(self):
self.connect()
yield self.connect()
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
self.read_headers()
yield self.read_headers()
self.close()
@gen_test
def test_keepalive_chunked(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'POST / HTTP/1.0\r\nConnection: keep-alive\r\n'
yield self.connect()
self.stream.write(b'POST / HTTP/1.0\r\n'
b'Connection: keep-alive\r\n'
b'Transfer-Encoding: chunked\r\n'
b'\r\n0\r\n')
self.read_response()
b'\r\n'
b'0\r\n'
b'\r\n')
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
yield self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
@ -775,7 +839,7 @@ class GzipBaseTest(object):
def test_uncompressed(self):
response = self.fetch('/', method='POST', body='foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
@ -784,7 +848,7 @@ class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
def test_gzip(self):
response = self.post_gzip('foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
@ -805,7 +869,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
return SimpleAsyncHTTPClient()
def get_httpserver_options(self):
return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
@ -900,11 +964,14 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read", required=False):
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
# 431 is "Request Header Fields Too Large", defined in RFC
# 6585. However, many implementations just close the
# connection in this case, resulting in a 599.
self.assertIn(response.code, (431, 599))
try:
self.fetch("/", headers={'X-Filler': 'a' * 1000}, raise_error=True)
self.fail("did not raise expected exception")
except HTTPError as e:
# 431 is "Request Header Fields Too Large", defined in RFC
# 6585. However, many implementations just close the
# connection in this case, resulting in a 599.
self.assertIn(e.response.code, (431, 599))
@skipOnTravis
@ -924,34 +991,35 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
for stream in self.streams:
stream.close()
@gen.coroutine
def connect(self):
stream = IOStream(socket.socket())
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
yield stream.connect(('127.0.0.1', self.get_http_port()))
self.streams.append(stream)
return stream
raise gen.Return(stream)
@gen_test
def test_unused_connection(self):
stream = self.connect()
stream.set_close_callback(self.stop)
self.wait()
stream = yield self.connect()
event = Event()
stream.set_close_callback(event.set)
yield event.wait()
@gen_test
def test_idle_after_use(self):
stream = self.connect()
stream.set_close_callback(lambda: self.stop("closed"))
stream = yield self.connect()
event = Event()
stream.set_close_callback(event.set)
# Use the connection twice to make sure keep-alives are working
for i in range(2):
stream.write(b"GET / HTTP/1.1\r\n\r\n")
stream.read_until(b"\r\n\r\n", self.stop)
self.wait()
stream.read_bytes(11, self.stop)
data = self.wait()
yield stream.read_until(b"\r\n\r\n")
data = yield stream.read_bytes(11)
self.assertEqual(data, b"Hello world")
# Now let the timeout trigger and close the connection.
data = self.wait()
self.assertEqual(data, "closed")
yield event.wait()
class BodyLimitsTest(AsyncHTTPTestCase):
@ -988,7 +1056,7 @@ class BodyLimitsTest(AsyncHTTPTestCase):
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
return SimpleAsyncHTTPClient()
def test_small_body(self):
response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
@ -999,24 +1067,27 @@ class BodyLimitsTest(AsyncHTTPTestCase):
def test_large_body_buffered(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
self.assertEqual(response.code, 400)
@unittest.skipIf(os.name == 'nt', 'flaky on windows')
def test_large_body_buffered_chunked(self):
# This test is flaky on windows for unknown reasons.
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/buffered', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
self.assertEqual(response.code, 400)
def test_large_body_streaming(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
self.assertEqual(response.code, 400)
@unittest.skipIf(os.name == 'nt', 'flaky on windows')
def test_large_body_streaming_chunked(self):
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/streaming', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
self.assertEqual(response.code, 400)
def test_large_body_streaming_override(self):
response = self.fetch('/streaming?expected_size=10240', method='PUT',
@ -1053,14 +1124,16 @@ class BodyLimitsTest(AsyncHTTPTestCase):
stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
stream.write(b'a' * 10240)
headers, response = yield gen.Task(read_stream_body, stream)
fut = Future()
read_stream_body(stream, callback=fut.set_result)
start_line, headers, response = yield fut
self.assertEqual(response, b'10240')
# Without the ?expected_size parameter, we get the old default value
stream.write(b'PUT /streaming HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
with ExpectLog(gen_log, '.*Content-Length too long'):
data = yield stream.read_until_close()
self.assertEqual(data, b'')
self.assertEqual(data, b'HTTP/1.1 400 Bad Request\r\n\r\n')
finally:
stream.close()
@ -1081,10 +1154,10 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
request.connection.finish()
return
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
request.write(message)
request.finish()
request.connection.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
request.connection.write(message)
request.connection.finish()
return handle_request
def test_legacy_interface(self):

View file

@ -1,13 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
from tornado.httputil import (
url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp,
HTTPServerRequest, parse_request_start_line, parse_cookie, qs_to_qsl,
HTTPInputError,
)
from tornado.escape import utf8, native_str
from tornado.util import PY3
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
from tornado.util import u
import copy
import datetime
@ -15,6 +18,11 @@ import logging
import pickle
import time
if PY3:
import urllib.parse as urllib_parse
else:
import urlparse as urllib_parse
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
@ -43,14 +51,14 @@ class TestUrlConcat(unittest.TestCase):
"https://localhost/path?x",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat(
"https://localhost/path?x&",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat(
@ -66,6 +74,52 @@ class TestUrlConcat(unittest.TestCase):
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_none_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
None,
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_with_frag(self):
url = url_concat(
"https://localhost/path#tab",
[('y', 'y')],
)
self.assertEqual(url, "https://localhost/path?y=y#tab")
def test_url_concat_multi_same_params(self):
url = url_concat(
"https://localhost/path",
[('y', 'y1'), ('y', 'y2')],
)
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
def test_url_concat_multi_same_query_params(self):
url = url_concat(
"https://localhost/path?r=1&r=2",
[('y', 'y')],
)
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
def test_url_concat_dict_params(self):
url = url_concat(
"https://localhost/path",
dict(y='y'),
)
self.assertEqual(url, "https://localhost/path?y=y")
class QsParseTest(unittest.TestCase):
    """Tests for qs_to_qsl, which flattens the dict-of-lists produced by
    urllib's parse_qs back into a flat sequence of (key, value) pairs.
    """

    def test_parsing(self):
        qsstring = "a=1&b=2&a=3"
        qs = urllib_parse.parse_qs(qsstring)
        qsl = list(qs_to_qsl(qs))
        # Only membership is asserted, not order: a repeated key ('a')
        # must yield one pair per value.
        self.assertIn(('a', '1'), qsl)
        self.assertIn(('a', '3'), qsl)
        self.assertIn(('b', '2'), qsl)
self.assertIn(('b', '2'), qsl)
class MultipartFormDataTest(unittest.TestCase):
def test_file_upload(self):
@ -122,6 +176,20 @@ Foo
self.assertEqual(file["filename"], filename)
self.assertEqual(file["body"], b"Foo")
    def test_non_ascii_filename(self):
        """A filename* parameter (RFC 5987 ext-value) overrides the plain
        ASCII filename parameter; the decoded name must be unicode.
        """
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt

Foo
--1234--""".replace(b"\n", b"\r\n")
        args = {}
        files = {}
        parse_multipart_form_data(b"1234", data, args, files)
        file = files["files"][0]
        # %C3%A1 percent-decodes to U+00E1 (a-acute), so the effective
        # filename is the filename* value, not "ab.txt".
        self.assertEqual(file["filename"], u"áb.txt")
        self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
data = b'''\
--1234
@ -230,6 +298,13 @@ Foo: even
("Foo", "bar baz"),
("Foo", "even more lines")])
    def test_malformed_continuation(self):
        """A leading-whitespace first line must raise HTTPInputError."""
        # If the first line starts with whitespace, it's a
        # continuation line with nothing to continue, so reject it
        # (with a proper error).
        data = " Foo: bar"
        self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)
def test_unicode_newlines(self):
# Ensure that only \r\n is recognized as a header separator, and not
# the other newline-like unicode characters.
@ -238,13 +313,13 @@ Foo: even
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
u('\u001b'), # VERTICAL TAB
u('\u001c'), # FILE SEPARATOR
u('\u001d'), # GROUP SEPARATOR
u('\u001e'), # RECORD SEPARATOR
u('\u0085'), # NEXT LINE
u('\u2028'), # LINE SEPARATOR
u('\u2029'), # PARAGRAPH SEPARATOR
u'\u001b', # VERTICAL TAB
u'\u001c', # FILE SEPARATOR
u'\u001d', # GROUP SEPARATOR
u'\u001e', # RECORD SEPARATOR
u'\u0085', # NEXT LINE
u'\u2028', # LINE SEPARATOR
u'\u2029', # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
@ -319,6 +394,14 @@ Foo: even
self.assertEqual(headers['quux'], 'xyzzy')
self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
def test_string(self):
headers = HTTPHeaders()
headers.add("Foo", "1")
headers.add("Foo", "2")
headers.add("Foo", "3")
headers2 = HTTPHeaders.parse(str(headers))
self.assertEquals(headers, headers2)
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
@ -359,6 +442,10 @@ class HTTPServerRequestTest(unittest.TestCase):
requets = HTTPServerRequest(uri='/')
self.assertIsInstance(requets.body, bytes)
def test_repr_does_not_contain_headers(self):
request = HTTPServerRequest(uri='/', headers={'Canary': 'Coal Mine'})
self.assertTrue('Canary' not in repr(request))
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
@ -371,3 +458,59 @@ class ParseRequestStartLineTest(unittest.TestCase):
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)
class ParseCookieTest(unittest.TestCase):
    """Tests for tornado.httputil.parse_cookie's permissive parsing."""
    # These tests copied from Django:
    # https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28

    def test_python_cookies(self):
        """
        Test cases copied from Python's Lib/test/test_http_cookies.py
        """
        self.assertEqual(parse_cookie('chips=ahoy; vienna=finger'),
                         {'chips': 'ahoy', 'vienna': 'finger'})
        # Here parse_cookie() differs from Python's cookie parsing in that it
        # treats all semicolons as delimiters, even within quotes.
        self.assertEqual(
            parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
            {'keebler': '"E=mc2', 'L': '\\"Loves\\"', 'fudge': '\\012', '': '"'}
        )
        # Illegal cookies that have an '=' char in an unquoted value.
        self.assertEqual(parse_cookie('keebler=E=mc2'), {'keebler': 'E=mc2'})
        # Cookies with ':' character in their name.
        self.assertEqual(parse_cookie('key:term=value:term'), {'key:term': 'value:term'})
        # Cookies with '[' and ']'.
        self.assertEqual(parse_cookie('a=b; c=[; d=r; f=h'),
                         {'a': 'b', 'c': '[', 'd': 'r', 'f': 'h'})

    def test_cookie_edgecases(self):
        # Cookies that RFC6265 allows.
        self.assertEqual(parse_cookie('a=b; Domain=example.com'),
                         {'a': 'b', 'Domain': 'example.com'})
        # parse_cookie() has historically kept only the last cookie with the
        # same name.
        self.assertEqual(parse_cookie('a=b; h=i; a=c'), {'a': 'c', 'h': 'i'})

    def test_invalid_cookies(self):
        """
        Cookie strings that go against RFC6265 but browsers will send if set
        via document.cookie.
        """
        # Chunks without an equals sign appear as unnamed values per
        # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
        self.assertIn('django_language',
                      parse_cookie('abc=def; unnamed; django_language=en').keys())
        # Even a double quote may be an unnamed value.
        self.assertEqual(parse_cookie('a=b; "; c=d'), {'a': 'b', '': '"', 'c': 'd'})
        # Spaces in names and values, and an equals sign in values.
        self.assertEqual(parse_cookie('a b c=d e = f; gh=i'), {'a b c': 'd e = f', 'gh': 'i'})
        # More characters the spec forbids.
        self.assertEqual(parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
                         {'a b,c<>@:/[]?{}': 'd " =e,f g'})
        # Unicode characters. The spec only allows ASCII.
        self.assertEqual(parse_cookie('saint=André Bessette'),
                         {'saint': native_str('André Bessette')})
        # Browsers don't send extra whitespace or semicolons in Cookie headers,
        # but parse_cookie() should parse whitespace the same way
        # document.cookie parses whitespace.
        self.assertEqual(parse_cookie(' = b ; ; = ; c = ; '), {'': 'b', 'c': ''})

View file

@ -1,47 +1,73 @@
# flake8: noqa
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import subprocess
import sys
from tornado.test.util import unittest
_import_everything = b"""
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
# at startup, which in turn creates the default event loop and prevents forking.
# Explicitly disallow the default event loop so that an error will be raised
# if something tries to touch it.
try:
import asyncio
except ImportError:
pass
else:
asyncio.set_event_loop(None)
import tornado.auth
import tornado.autoreload
import tornado.concurrent
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.tcpserver
import tornado.tcpclient
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient
"""
class ImportTest(unittest.TestCase):
def test_import_everything(self):
# Some of our modules are not otherwise tested. Import them
# all (unless they have external dependencies) here to at
# least ensure that there are no syntax errors.
import tornado.auth
import tornado.autoreload
import tornado.concurrent
# import tornado.curl_httpclient # depends on pycurl
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
# Test that all Tornado modules can be imported without side effects,
# specifically without initializing the default asyncio event loop.
# Since we can't tell which modules may have already beein imported
# in our process, do it in a subprocess for a clean slate.
proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
proc.communicate(_import_everything)
self.assertEqual(proc.returncode, 0)
def test_import_aliases(self):
# Ensure we don't delete formerly-documented aliases accidentally.
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.tcpserver
import tornado.template
import tornado.testing
import tornado.gen
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
# for modules with dependencies, if those dependencies can be loaded,
# load them too.
def test_import_pycurl(self):
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)

View file

@ -1,28 +1,48 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
from concurrent.futures import ThreadPoolExecutor
import contextlib
import datetime
import functools
import socket
import subprocess
import sys
import threading
import time
import types
try:
from unittest import mock # type: ignore
except ImportError:
try:
import mock # type: ignore
except ImportError:
mock = None
from tornado.escape import native_str
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
from tornado.log import app_log
from tornado.platform.select import _Select
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis, skipBefore35, exec_test
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import (unittest, skipIfNonUnix, skipOnTravis,
skipBefore35, exec_test, ignore_deprecation)
try:
from concurrent import futures
except ImportError:
futures = None
try:
import asyncio
except ImportError:
asyncio = None
try:
import twisted
except ImportError:
twisted = None
class FakeTimeSelect(_Select):
def __init__(self):
@ -61,6 +81,25 @@ class FakeTimeIOLoop(PollIOLoop):
class TestIOLoop(AsyncTestCase):
    def test_add_callback_return_sequence(self):
        # A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
        self.calls = 0

        loop = self.io_loop
        test = self
        old_add_callback = loop.add_callback

        # Monkeypatch add_callback on this loop instance so every scheduled
        # callback is counted; the original implementation still runs.
        def add_callback(self, callback, *args, **kwargs):
            test.calls += 1
            old_add_callback(callback, *args, **kwargs)

        loop.add_callback = types.MethodType(add_callback, loop)
        loop.add_callback(lambda: {})
        loop.add_callback(lambda: [])
        loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
        loop.start()
        # If a falsy-but-non-None return value were treated as "reschedule",
        # the count would explode during the 50ms window.
        self.assertLess(self.calls, 10)
@skipOnTravis
def test_add_callback_wakeup(self):
# Make sure that add_callback from inside a running IOLoop
@ -138,8 +177,9 @@ class TestIOLoop(AsyncTestCase):
other_ioloop.close()
def test_add_callback_while_closing(self):
# Issue #635: add_callback() should raise a clean exception
# if called while another thread is closing the IOLoop.
# add_callback should not fail if it races with another thread
# closing the IOLoop. The callbacks are dropped silently
# without executing.
closing = threading.Event()
def target():
@ -152,11 +192,7 @@ class TestIOLoop(AsyncTestCase):
thread.start()
closing.wait()
for i in range(1000):
try:
other_ioloop.add_callback(lambda: None)
except RuntimeError as e:
self.assertEqual("IOLoop is closing", str(e))
break
other_ioloop.add_callback(lambda: None)
def test_handle_callback_exception(self):
# IOLoop.handle_callback_exception can be overridden to catch
@ -241,7 +277,9 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
# The asyncio event loop does not guarantee the order of these
# callbacks, but PollIOLoop does.
self.assertEqual(sorted(results), [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
@ -368,25 +406,29 @@ class TestIOLoop(AsyncTestCase):
"""The IOLoop examines exceptions from awaitables and logs them."""
namespace = exec_test(globals(), locals(), """
async def callback():
self.io_loop.add_callback(self.stop)
# Stop the IOLoop two iterations after raising an exception
# to give the exception time to be logged.
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
1 / 0
""")
with NullContext():
self.io_loop.add_callback(namespace["callback"])
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
with ignore_deprecation():
# An added callback runs in the test's stack_context, so will be
# re-raised in wait().
self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
@ -407,7 +449,7 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.03, self.stop)
self.io_loop.call_later(0.1, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
@ -416,6 +458,16 @@ class TestIOLoop(AsyncTestCase):
client.close()
server.close()
    @gen_test
    def test_init_close_race(self):
        # Regression test for #2367
        #
        # Creating and closing IOLoops from two executor threads
        # concurrently must not raise.
        def f():
            for i in range(10):
                loop = IOLoop()
                loop.close()

        yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
@ -463,6 +515,16 @@ class TestIOLoopCurrent(unittest.TestCase):
self.assertIs(self.io_loop, IOLoop.current())
class TestIOLoopCurrentAsync(AsyncTestCase):
    @gen_test
    def test_clear_without_current(self):
        # If there is no current IOLoop, clear_current is a no-op (but
        # should not fail). Use a thread so we see the threading.Local
        # in a pristine state.
        with ThreadPoolExecutor(1) as e:
            yield e.submit(IOLoop.clear_current)
class TestIOLoopAddCallback(AsyncTestCase):
def setUp(self):
super(TestIOLoopAddCallback, self).setUp()
@ -485,11 +547,12 @@ class TestIOLoopAddCallback(AsyncTestCase):
self.assertNotIn('c2', self.active_contexts)
self.stop()
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with ignore_deprecation():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)
self.wait()
@ -503,11 +566,12 @@ class TestIOLoopAddCallback(AsyncTestCase):
self.assertNotIn('c2', self.active_contexts)
self.stop((foo, bar))
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with ignore_deprecation():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)
result = self.wait()
self.assertEqual(result, (1, 2))
@ -552,15 +616,86 @@ class TestIOLoopFutures(AsyncTestCase):
# stack_context propagates to the ioloop callback, but the worker
# task just has its exceptions caught and saved in the Future.
with futures.ThreadPoolExecutor(1) as pool:
with ExceptionStackContext(handle_exception):
self.io_loop.add_future(pool.submit(task), callback)
ready.set()
self.wait()
with ignore_deprecation():
with futures.ThreadPoolExecutor(1) as pool:
with ExceptionStackContext(handle_exception):
self.io_loop.add_future(pool.submit(task), callback)
ready.set()
self.wait()
self.assertEqual(self.exception.args[0], "callback")
self.assertEqual(self.future.exception().args[0], "worker")
    @gen_test
    def test_run_in_executor_gen(self):
        """run_in_executor futures are yieldable from a gen.coroutine."""
        event1 = threading.Event()
        event2 = threading.Event()

        def sync_func(self_event, other_event):
            self_event.set()
            other_event.wait()
            # Note that return value doesn't actually do anything,
            # it is just passed through to our final assertion to
            # make sure it is passed through properly.
            return self_event

        # Run two synchronous functions, which would deadlock if not
        # run in parallel.
        res = yield [
            IOLoop.current().run_in_executor(None, sync_func, event1, event2),
            IOLoop.current().run_in_executor(None, sync_func, event2, event1)
        ]
        self.assertEqual([event1, event2], res)
    @skipBefore35
    @gen_test
    def test_run_in_executor_native(self):
        """run_in_executor results work with native ``await``."""
        event1 = threading.Event()
        event2 = threading.Event()

        def sync_func(self_event, other_event):
            self_event.set()
            other_event.wait()
            return self_event

        # Go through an async wrapper to ensure that the result of
        # run_in_executor works with await and not just gen.coroutine
        # (simply passing the underlying concurrent future would do that).
        namespace = exec_test(globals(), locals(), """
            async def async_wrapper(self_event, other_event):
                return await IOLoop.current().run_in_executor(
                    None, sync_func, self_event, other_event)
        """)

        res = yield [
            namespace["async_wrapper"](event1, event2),
            namespace["async_wrapper"](event2, event1)
        ]
        self.assertEqual([event1, event2], res)
    @gen_test
    def test_set_default_executor(self):
        """set_default_executor makes run_in_executor(None, ...) use it."""
        count = [0]

        # Counting subclass proves submissions go through our executor.
        class MyExecutor(futures.ThreadPoolExecutor):
            def submit(self, func, *args):
                count[0] += 1
                return super(MyExecutor, self).submit(func, *args)

        event = threading.Event()

        def sync_func():
            event.set()

        executor = MyExecutor(1)
        loop = IOLoop.current()
        loop.set_default_executor(executor)
        yield loop.run_in_executor(None, sync_func)
        self.assertEqual(1, count[0])
        self.assertTrue(event.is_set())
class TestIOLoopRunSync(unittest.TestCase):
def setUp(self):
@ -580,14 +715,14 @@ class TestIOLoopRunSync(unittest.TestCase):
def test_async_result(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
raise gen.Return(42)
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
1 / 0
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
@ -600,18 +735,24 @@ class TestIOLoopRunSync(unittest.TestCase):
def test_timeout(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
yield gen.sleep(1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
@skipBefore35
def test_native_coroutine(self):
@gen.coroutine
def f1():
yield gen.moment
namespace = exec_test(globals(), locals(), """
async def f():
await gen.Task(self.io_loop.add_callback)
async def f2():
await f1()
""")
self.io_loop.run_sync(namespace['f'])
self.io_loop.run_sync(namespace['f2'])
@unittest.skipIf(asyncio is not None,
'IOLoop configuration not available')
class TestPeriodicCallback(unittest.TestCase):
def setUp(self):
self.io_loop = FakeTimeIOLoop()
@ -653,6 +794,149 @@ class TestPeriodicCallback(unittest.TestCase):
self.io_loop.start()
self.assertEqual(calls, expected)
    def test_io_loop_set_at_start(self):
        # Check PeriodicCallback uses the current IOLoop at start() time,
        # not at instantiation time.
        calls = []
        io_loop = FakeTimeIOLoop()

        def cb():
            calls.append(io_loop.time())
        # No io_loop argument: the callback must pick up io_loop because
        # make_current() is called before start().
        pc = PeriodicCallback(cb, 10000)
        io_loop.make_current()
        pc.start()
        io_loop.call_later(50, io_loop.stop)
        io_loop.start()
        # FakeTimeIOLoop starts its clock at 1000; 10s period over 50s.
        self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
        io_loop.close()
class TestPeriodicCallbackMath(unittest.TestCase):
    """Unit tests for PeriodicCallback's scheduling arithmetic.

    These exercise the private _update_next/_next_timeout machinery
    directly, with no IOLoop involved.
    """

    def simulate_calls(self, pc, durations):
        """Simulate a series of calls to the PeriodicCallback.

        Pass a list of call durations in seconds (negative values
        work to simulate clock adjustments during the call, or more or
        less equivalently, between calls). This method returns the
        times at which each call would be made.
        """
        calls = []
        # Simulation starts at t=1000 (matches FakeTimeIOLoop's epoch).
        now = 1000
        pc._next_timeout = now
        for d in durations:
            pc._update_next(now)
            calls.append(pc._next_timeout)
            now = pc._next_timeout + d
        return calls

    def test_basic(self):
        # 10000ms period -> one call every 10 simulated seconds.
        pc = PeriodicCallback(None, 10000)
        self.assertEqual(self.simulate_calls(pc, [0] * 5),
                         [1010, 1020, 1030, 1040, 1050])

    def test_overrun(self):
        # If a call runs for too long, we skip entire cycles to get
        # back on schedule.
        call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
        expected = [
            1010, 1020, 1030,  # first 3 calls on schedule
            1050, 1070,  # next 2 delayed one cycle
            1100, 1130,  # next 2 delayed 2 cycles
            1170, 1210,  # next 2 delayed 3 cycles
            1220, 1230,  # then back on schedule.
        ]

        pc = PeriodicCallback(None, 10000)
        self.assertEqual(self.simulate_calls(pc, call_durations),
                         expected)

    def test_clock_backwards(self):
        pc = PeriodicCallback(None, 10000)
        # Backwards jumps are ignored, potentially resulting in a
        # slightly slow schedule (although we assume that when
        # time.time() and time.monotonic() are different, time.time()
        # is getting adjusted by NTP and is therefore more accurate)
        self.assertEqual(self.simulate_calls(pc, [-2, -1, -3, -2, 0]),
                         [1010, 1020, 1030, 1040, 1050])

        # For big jumps, we should perhaps alter the schedule, but we
        # don't currently. This trace shows that we run callbacks
        # every 10s of time.time(), but the first and second calls are
        # 110s of real time apart because the backwards jump is
        # ignored.
        self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]),
                         [1010, 1020, 1030])

    @unittest.skipIf(mock is None, 'mock package not present')
    def test_jitter(self):
        # With jitter=0.5 the interval is scaled by the patched
        # random.random() values; expected times follow from those.
        random_times = [0.5, 1, 0, 0.75]
        expected = [1010, 1022.5, 1030, 1041.25]
        call_durations = [0] * len(random_times)
        pc = PeriodicCallback(None, 10000, jitter=0.5)

        def mock_random():
            return random_times.pop(0)
        with mock.patch('random.random', mock_random):
            self.assertEqual(self.simulate_calls(pc, call_durations),
                             expected)
class TestIOLoopConfiguration(unittest.TestCase):
    """Verify which IOLoop implementation is selected under various
    configurations, using subprocesses so each test gets a clean slate.
    """

    def run_python(self, *statements):
        # Run the given statements in a fresh interpreter and return
        # its (stripped) stdout.
        statements = [
            'from tornado.ioloop import IOLoop, PollIOLoop',
            'classname = lambda x: x.__class__.__name__',
        ] + list(statements)
        args = [sys.executable, '-c', '; '.join(statements)]
        return native_str(subprocess.check_output(args)).strip()

    def test_default(self):
        if asyncio is not None:
            # When asyncio is available, it is used by default.
            cls = self.run_python('print(classname(IOLoop.current()))')
            self.assertEqual(cls, 'AsyncIOMainLoop')
            cls = self.run_python('print(classname(IOLoop()))')
            self.assertEqual(cls, 'AsyncIOLoop')
        else:
            # Otherwise, the default is a subclass of PollIOLoop
            is_poll = self.run_python(
                'print(isinstance(IOLoop.current(), PollIOLoop))')
            self.assertEqual(is_poll, 'True')

    @unittest.skipIf(asyncio is not None,
                     "IOLoop configuration not available")
    def test_explicit_select(self):
        # SelectIOLoop can always be configured explicitly.
        default_class = self.run_python(
            'IOLoop.configure("tornado.platform.select.SelectIOLoop")',
            'print(classname(IOLoop.current()))')
        self.assertEqual(default_class, 'SelectIOLoop')

    @unittest.skipIf(asyncio is None, "asyncio module not present")
    def test_asyncio(self):
        # Configuring AsyncIOLoop still yields AsyncIOMainLoop as current.
        cls = self.run_python(
            'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
            'print(classname(IOLoop.current()))')
        self.assertEqual(cls, 'AsyncIOMainLoop')

    @unittest.skipIf(asyncio is None, "asyncio module not present")
    def test_asyncio_main(self):
        cls = self.run_python(
            'from tornado.platform.asyncio import AsyncIOMainLoop',
            'AsyncIOMainLoop().install()',
            'print(classname(IOLoop.current()))')
        self.assertEqual(cls, 'AsyncIOMainLoop')

    @unittest.skipIf(twisted is None, "twisted module not present")
    @unittest.skipIf(asyncio is not None,
                     "IOLoop configuration not available")
    def test_twisted(self):
        cls = self.run_python(
            'from tornado.platform.twisted import TwistedIOLoop',
            'TwistedIOLoop().install()',
            'print(classname(IOLoop.current()))')
        self.assertEqual(cls, 'TwistedIOLoop')
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -1,4 +1,4 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import datetime
import os
@ -8,7 +8,7 @@ import tempfile
import tornado.locale
from tornado.escape import utf8, to_unicode
from tornado.test.util import unittest, skipOnAppEngine
from tornado.util import u, unicode_type
from tornado.util import unicode_type
class TranslationLoaderTest(unittest.TestCase):
@ -35,7 +35,7 @@ class TranslationLoaderTest(unittest.TestCase):
os.path.join(os.path.dirname(__file__), 'csv_translations'))
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
# tempfile.mkdtemp is not available on app engine.
@skipOnAppEngine
@ -55,7 +55,7 @@ class TranslationLoaderTest(unittest.TestCase):
tornado.locale.load_translations(tmpdir)
locale = tornado.locale.get('fr_FR')
self.assertIsInstance(locale, tornado.locale.CSVLocale)
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
finally:
shutil.rmtree(tmpdir)
@ -65,20 +65,20 @@ class TranslationLoaderTest(unittest.TestCase):
"tornado_test")
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u"le club")
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u"les clubs")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
self.assertTrue(isinstance(name, unicode_type))
self.assertEqual(name, u('Espa\u00f1ol'))
self.assertEqual(name, u'Espa\u00f1ol')
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
@ -89,16 +89,17 @@ class EnglishTest(unittest.TestCase):
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
self.assertEqual(locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(hours=2), full_format=False),
self.assertEqual(locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
'2 hours ago')
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1), full_format=False, shorter=True),
'yesterday')
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1),
full_format=False, shorter=True), 'yesterday')
date = now - datetime.timedelta(days=2)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),

View file

@ -11,7 +11,7 @@
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from datetime import timedelta
from tornado import gen, locks
@ -35,6 +35,16 @@ class ConditionTest(AsyncTestCase):
self.history.append(key)
future.add_done_callback(callback)
def loop_briefly(self):
"""Run all queued callbacks on the IOLoop.
In these tests, this method is used after calling notify() to
preserve the pre-5.0 behavior in which callbacks ran
synchronously.
"""
self.io_loop.add_callback(self.stop)
self.wait()
def test_repr(self):
c = locks.Condition()
self.assertIn('Condition', repr(c))
@ -53,8 +63,10 @@ class ConditionTest(AsyncTestCase):
self.record_done(c.wait(), 'wait1')
self.record_done(c.wait(), 'wait2')
c.notify(1)
self.loop_briefly()
self.history.append('notify1')
c.notify(1)
self.loop_briefly()
self.history.append('notify2')
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
self.history)
@ -65,12 +77,15 @@ class ConditionTest(AsyncTestCase):
self.record_done(c.wait(), i)
c.notify(3)
self.loop_briefly()
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(3)), self.history)
c.notify(1)
self.loop_briefly()
self.assertEqual(list(range(4)), self.history)
c.notify(2)
self.loop_briefly()
self.assertEqual(list(range(6)), self.history)
def test_notify_all(self):
@ -79,6 +94,7 @@ class ConditionTest(AsyncTestCase):
self.record_done(c.wait(), i)
c.notify_all()
self.loop_briefly()
self.history.append('notify_all')
# Callbacks execute in the order they were registered.
@ -125,6 +141,7 @@ class ConditionTest(AsyncTestCase):
self.assertEqual(['timeout', 0, 2], self.history)
self.assertEqual(['timeout', 0, 2], self.history)
c.notify()
yield
self.assertEqual(['timeout', 0, 2, 3], self.history)
@gen_test
@ -139,6 +156,7 @@ class ConditionTest(AsyncTestCase):
self.assertEqual(['timeout'], self.history)
c.notify_all()
yield
self.assertEqual(['timeout', 0, 2], self.history)
@gen_test
@ -154,6 +172,7 @@ class ConditionTest(AsyncTestCase):
# resolving third future.
futures[1].add_done_callback(lambda _: c.notify())
c.notify(2)
yield
self.assertTrue(all(f.done() for f in futures))
@gen_test

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
@ -13,7 +12,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import contextlib
import glob
@ -29,7 +28,7 @@ from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import u, basestring_type
from tornado.util import basestring_type
@contextlib.contextmanager
@ -42,7 +41,8 @@ def ignore_bytes_warning():
class LogFormatterTest(unittest.TestCase):
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
LINE_RE = re.compile(
b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
def setUp(self):
self.formatter = LogFormatter(color=False)
@ -51,9 +51,9 @@ class LogFormatterTest(unittest.TestCase):
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
self.formatter._colors = {
logging.ERROR: u("\u0001"),
logging.ERROR: u"\u0001",
}
self.formatter._normal = u("\u0002")
self.formatter._normal = u"\u0002"
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger('LogFormatterTest')
self.logger.propagate = False
@ -96,16 +96,16 @@ class LogFormatterTest(unittest.TestCase):
def test_utf8_logging(self):
with ignore_bytes_warning():
self.logger.error(u("\u00e9").encode("utf8"))
self.logger.error(u"\u00e9".encode("utf8"))
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
else:
# on python 3, byte strings always get repr'd even if
# they're ascii-only, so this degenerates into another
# copy of test_bytes_logging.
self.assertEqual(self.get_output(), utf8(repr(utf8(u("\u00e9")))))
self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))
def test_bytes_exception_logging(self):
try:
@ -128,8 +128,8 @@ class UnicodeLogFormatterTest(LogFormatterTest):
return logging.FileHandler(filename, encoding="utf8")
def test_unicode_logging(self):
self.logger.error(u("\u00e9"))
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
self.logger.error(u"\u00e9")
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
class EnablePrettyLoggingTest(unittest.TestCase):

View file

@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import errno
import os
import signal
import socket
@ -7,10 +8,12 @@ from subprocess import Popen
import sys
import time
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
from tornado.netutil import (
BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
)
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import unittest, skipIfNoNetwork
from tornado.test.util import unittest, skipIfNoNetwork, ignore_deprecation
try:
from concurrent import futures
@ -18,15 +21,15 @@ except ImportError:
futures = None
try:
import pycares
import pycares # type: ignore
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted
import twisted.names
import twisted # type: ignore
import twisted.names # type: ignore
except ImportError:
twisted = None
else:
@ -35,7 +38,8 @@ else:
class _ResolverTestMixin(object):
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
with ignore_deprecation():
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
@ -55,29 +59,30 @@ class _ResolverErrorTestMixin(object):
self.stop(exc_val)
return True # Halt propagation.
with ExceptionStackContext(handler):
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
with ignore_deprecation():
with ExceptionStackContext(handler):
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
result = self.wait()
self.assertIsInstance(result, Exception)
@gen_test
def test_future_interface_bad_host(self):
with self.assertRaises(Exception):
with self.assertRaises(IOError):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
raise socket.gaierror(errno.EIO, "mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.resolver = BlockingResolver()
# getaddrinfo-based tests need mocking to reliably generate errors;
@ -86,7 +91,7 @@ class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
@ -95,12 +100,31 @@ class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
super(BlockingResolverErrorTest, self).tearDown()
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(OverrideResolverTest, self).setUp()
mapping = {
('google.com', 80): ('1.2.3.4', 80),
('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
}
self.resolver = OverrideResolver(BlockingResolver(), mapping)
@gen_test
def test_resolve_multiaddr(self):
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET)
self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET6)
self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(ThreadedResolverTest, self).setUp()
self.resolver = ThreadedResolver(io_loop=self.io_loop)
self.resolver = ThreadedResolver()
def tearDown(self):
self.resolver.close()
@ -110,7 +134,7 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
@ -157,19 +181,23 @@ class ThreadedResolverImportTest(unittest.TestCase):
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(CaresResolverTest, self).setUp()
self.resolver = CaresResolver(io_loop=self.io_loop)
self.resolver = CaresResolver()
# TwistedResolver produces consistent errors in our test cases so we
# can test the regular and error cases in the same class.
# could test the regular and error cases in the same class. However,
# in the error cases it appears that cleanup of socket objects is
# handled asynchronously and occasionally results in "unclosed socket"
# warnings if not given time to shut down (and there is no way to
# explicitly shut it down). This makes the test flaky, so we do not
# test error cases here.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin,
_ResolverErrorTestMixin):
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop)
self.resolver = TwistedResolver()
class IsValidIPTest(unittest.TestCase):
@ -203,9 +231,10 @@ class TestPortAllocation(unittest.TestCase):
@unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
def test_reuse_port(self):
sockets = []
socket, port = bind_unused_port(reuse_port=True)
try:
sockets = bind_sockets(port, 'localhost', reuse_port=True)
sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
finally:
socket.close()

View file

@ -3,3 +3,5 @@ port=443
username='李康'
foo_bar='a'
my_path = __file__

View file

@ -1,28 +1,41 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import datetime
import os
import sys
from tornado.options import OptionParser, Error
from tornado.util import basestring_type
from tornado.test.util import unittest
from tornado.util import basestring_type, PY3
from tornado.test.util import unittest, subTest
if PY3:
from io import StringIO
else:
from cStringIO import StringIO
try:
from cStringIO import StringIO # python 2
except ImportError:
from io import StringIO # python 3
try:
from unittest import mock # python 3.3
# py33+
from unittest import mock # type: ignore
except ImportError:
try:
import mock # third-party mock package
import mock # type: ignore
except ImportError:
mock = None
class Email(object):
def __init__(self, value):
if isinstance(value, str) and '@' in value:
self._value = value
else:
raise ValueError()
@property
def value(self):
return self._value
class OptionsTest(unittest.TestCase):
def test_parse_command_line(self):
options = OptionParser()
@ -34,10 +47,13 @@ class OptionsTest(unittest.TestCase):
options = OptionParser()
options.define("port", default=80)
options.define("username", default='foo')
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEquals(options.port, 443)
options.define("my_path")
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"options_test.cfg")
options.parse_config_file(config_path)
self.assertEqual(options.port, 443)
self.assertEqual(options.username, "李康")
self.assertEqual(options.my_path, config_path)
def test_parse_callbacks(self):
options = OptionParser()
@ -131,6 +147,12 @@ class OptionsTest(unittest.TestCase):
options = self._sample_options()
self.assertEqual(1, options['a'])
def test_setitem(self):
options = OptionParser()
options.define('foo', default=1, type=int)
options['foo'] = 2
self.assertEqual(options['foo'], 2)
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
@ -179,7 +201,7 @@ class OptionsTest(unittest.TestCase):
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def test_types(self):
def _define_options(self):
options = OptionParser()
options.define('str', type=str)
options.define('basestring', type=basestring_type)
@ -187,13 +209,11 @@ class OptionsTest(unittest.TestCase):
options.define('float', type=float)
options.define('datetime', type=datetime.datetime)
options.define('timedelta', type=datetime.timedelta)
options.parse_command_line(['main.py',
'--str=asdf',
'--basestring=qwer',
'--int=42',
'--float=1.5',
'--datetime=2013-04-28 05:16',
'--timedelta=45s'])
options.define('email', type=Email)
options.define('list-of-int', type=int, multiple=True)
return options
def _check_options_values(self, options):
self.assertEqual(options.str, 'asdf')
self.assertEqual(options.basestring, 'qwer')
self.assertEqual(options.int, 42)
@ -201,6 +221,30 @@ class OptionsTest(unittest.TestCase):
self.assertEqual(options.datetime,
datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
self.assertEqual(options.email.value, 'tornado@web.com')
self.assertTrue(isinstance(options.email, Email))
self.assertEqual(options.list_of_int, [1, 2, 3])
def test_types(self):
options = self._define_options()
options.parse_command_line(['main.py',
'--str=asdf',
'--basestring=qwer',
'--int=42',
'--float=1.5',
'--datetime=2013-04-28 05:16',
'--timedelta=45s',
'--email=tornado@web.com',
'--list-of-int=1,2,3'])
self._check_options_values(options)
def test_types_with_conf_file(self):
for config_file_name in ("options_test_types.cfg",
"options_test_types_str.cfg"):
options = self._define_options()
options.parse_config_file(os.path.join(os.path.dirname(__file__),
config_file_name))
self._check_options_values(options)
def test_multiple_string(self):
options = OptionParser()
@ -222,6 +266,24 @@ class OptionsTest(unittest.TestCase):
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.*already defined')
def test_error_redefine_underscore(self):
# Ensure that the dash/underscore normalization doesn't
# interfere with the redefinition error.
tests = [
('foo-bar', 'foo-bar'),
('foo_bar', 'foo_bar'),
('foo-bar', 'foo_bar'),
('foo_bar', 'foo-bar'),
]
for a, b in tests:
with subTest(self, a=a, b=b):
options = OptionParser()
options.define(a)
with self.assertRaises(Error) as cm:
options.define(b)
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.bar.*already defined')
def test_dash_underscore_cli(self):
# Dashes and underscores should be interchangeable.
for defined_name in ['foo-bar', 'foo_bar']:

View file

@ -0,0 +1,11 @@
from datetime import datetime, timedelta
from tornado.test.options_test import Email
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = datetime(2013, 4, 28, 5, 16)
timedelta = timedelta(0, 45)
email = Email('tornado@web.com')
list_of_int = [1, 2, 3]

View file

@ -0,0 +1,8 @@
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = '2013-04-28 05:16'
timedelta = '45s'
email = 'tornado@web.com'
list_of_int = '1,2,3'

View file

@ -1,12 +1,11 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
import logging
import os
import signal
import subprocess
import sys
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -17,12 +16,15 @@ from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
try:
import asyncio
except ImportError:
asyncio = None
def skip_if_twisted():
if IOLoop.configured_class().__name__.endswith(('TwistedIOLoop',
'AsyncIOMainLoop')):
raise unittest.SkipTest("Process tests not compatible with "
"TwistedIOLoop or AsyncIOMainLoop")
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@ -58,11 +60,12 @@ class ProcessTest(unittest.TestCase):
super(ProcessTest, self).tearDown()
def test_multi_process(self):
# This test can't work on twisted because we use the global reactor
# and have no way to get it back into a sane state after the fork.
# This test doesn't work on twisted because we use the global
# reactor and don't restore it to a sane state after the fork
# (asyncio has the same issue, but we have a special case in
# place for it).
skip_if_twisted()
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
self.assertFalse(IOLoop.initialized())
sock, port = bind_unused_port()
def get_url(path):
@ -81,6 +84,10 @@ class ProcessTest(unittest.TestCase):
sock.close()
return
try:
if asyncio is not None:
# Reset the global asyncio event loop, which was put into
# a broken state by the fork.
asyncio.set_event_loop(asyncio.new_event_loop())
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
@ -136,6 +143,7 @@ class ProcessTest(unittest.TestCase):
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
@gen_test
def test_subprocess(self):
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
# This test fails non-deterministically with LayeredTwistedIOLoop.
@ -147,54 +155,52 @@ class SubprocessTest(AsyncTestCase):
"LayeredTwistedIOLoop")
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
self.addCleanup(subproc.stdout.close)
self.addCleanup(subproc.stdin.close)
yield subproc.stdout.read_until(b'>>> ')
subproc.stdin.write(b"print('hello')\n")
subproc.stdout.read_until(b'\n', self.stop)
data = self.wait()
data = yield subproc.stdout.read_until(b'\n')
self.assertEqual(data, b"hello\n")
subproc.stdout.read_until(b">>> ", self.stop)
self.wait()
yield subproc.stdout.read_until(b">>> ")
subproc.stdin.write(b"raise SystemExit\n")
subproc.stdout.read_until_close(self.stop)
data = self.wait()
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"")
@gen_test
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
yield subproc.stdout.read_until(b'>>> ')
subproc.stdin.close()
subproc.stdout.read_until_close(self.stop)
data = self.wait()
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"\n")
@gen_test
def test_stderr(self):
# This test is mysteriously flaky on twisted: it succeeds, but logs
# an error of EBADF on closing a file descriptor.
skip_if_twisted()
subproc = Subprocess([sys.executable, '-u', '-c',
r"import sys; sys.stderr.write('hello\n')"],
stderr=Subprocess.STREAM,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stderr.read_until(b'\n', self.stop)
data = self.wait()
stderr=Subprocess.STREAM)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
data = yield subproc.stderr.read_until(b'\n')
self.assertEqual(data, b'hello\n')
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
subproc.stderr.close()
def test_sigchild(self):
# Twisted's SIGCHLD handler and Subprocess's conflict with each other.
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'],
io_loop=self.io_loop)
subproc = Subprocess([sys.executable, '-c', 'pass'])
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
@ -212,14 +218,31 @@ class SubprocessTest(AsyncTestCase):
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c',
'import time; time.sleep(30)'],
io_loop=self.io_loop)
stdout=Subprocess.STREAM)
self.addCleanup(subproc.stdout.close)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
ret = self.wait()
try:
ret = self.wait(timeout=1.0)
except AssertionError:
# We failed to get the termination signal. This test is
# occasionally flaky on pypy, so try to get a little more
# information: did the process close its stdout
# (indicating that the problem is in the parent process's
# signal handling) or did the child process somehow fail
# to terminate?
subproc.stdout.read_until_close(callback=self.stop)
try:
self.wait(timeout=1.0)
except AssertionError:
raise AssertionError("subprocess failed to terminate")
else:
raise AssertionError("subprocess closed stdout but failed to "
"get termination signal")
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)

View file

@ -11,7 +11,7 @@
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from datetime import timedelta
from random import random

View file

@ -1,7 +1,6 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver
from tornado.util import u
# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
@ -9,4 +8,4 @@ from tornado.util import u
# this deadlock.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u('localhost'), 80))
IOLoop.current().run_sync(lambda: resolver.resolve(u'localhost', 80))

View file

@ -0,0 +1,247 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
from tornado.routing import HostMatches, PathMatches, ReversibleRouter, Router, Rule, RuleRouter
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, HTTPError, RequestHandler
from tornado.wsgi import WSGIContainer
class BasicRouter(Router):
def find_handler(self, request, **kwargs):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK"
)
self.connection.finish()
return MessageDelegate(request.connection)
class BasicRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return BasicRouter()
def test_basic_router(self):
response = self.fetch("/any_request")
self.assertEqual(response.body, b"OK")
resources = {}
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return HTTPMethodRouter(Application())
def test_http_method_router(self):
response = self.fetch("/post_resource", method="POST", body="data")
self.assertEqual(response.code, 200)
response = self.fetch("/get_resource")
self.assertEqual(response.code, 404)
response = self.fetch("/post_resource")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"data")
def _get_named_handler(handler_name):
class Handler(RequestHandler):
def get(self, *args, **kwargs):
if self.application.settings.get("app_name") is not None:
self.write(self.application.settings["app_name"] + ": ")
self.finish(handler_name + ": " + self.reverse_url(handler_name))
return Handler
FirstHandler = _get_named_handler("first_handler")
SecondHandler = _get_named_handler("second_handler")
class CustomRouter(ReversibleRouter):
def __init__(self):
super(CustomRouter, self).__init__()
self.routes = {}
def add_routes(self, routes):
self.routes.update(routes)
def find_handler(self, request, **kwargs):
if request.path in self.routes:
app, handler = self.routes[request.path]
return app.get_handler_delegate(request, handler)
def reverse_url(self, name, *args):
handler_path = '/' + name
return handler_path if handler_path in self.routes else None
class CustomRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
class CustomApplication(Application):
def reverse_url(self, name, *args):
return router.reverse_url(name, *args)
router = CustomRouter()
app1 = CustomApplication(app_name="app1")
app2 = CustomApplication(app_name="app2")
router.add_routes({
"/first_handler": (app1, FirstHandler),
"/second_handler": (app2, SecondHandler),
"/first_handler_second_app": (app2, FirstHandler),
})
return router
def test_custom_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"app1: first_handler: /first_handler")
response = self.fetch("/second_handler")
self.assertEqual(response.body, b"app2: second_handler: /second_handler")
response = self.fetch("/first_handler_second_app")
self.assertEqual(response.body, b"app2: first_handler: /first_handler")
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
response_body = b"OK"
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": str(len(response_body))}))
self.connection.write(response_body)
self.connection.finish()
return MessageDelegate(request_conn)
class RuleRouterTest(AsyncHTTPTestCase):
def get_app(self):
app = Application()
def request_callable(request):
request.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}))
request.connection.write(b"OK")
request.connection.finish()
router = CustomRouter()
router.add_routes({
"/nested_handler": (app, _get_named_handler("nested_handler"))
})
app.add_handlers(".*", [
(HostMatches("www.example.com"), [
(PathMatches("/first_handler"),
"tornado.test.routing_test.SecondHandler", {}, "second_handler")
]),
Rule(PathMatches("/.*handler"), router),
Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
Rule(PathMatches("/request_callable"), request_callable),
("/connection_delegate", ConnectionDelegate())
])
return app
def test_rule_based_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"first_handler: /first_handler")
response = self.fetch("/first_handler", headers={'Host': 'www.example.com'})
self.assertEqual(response.body, b"second_handler: /first_handler")
response = self.fetch("/nested_handler")
self.assertEqual(response.body, b"nested_handler: /nested_handler")
response = self.fetch("/nested_not_found_handler")
self.assertEqual(response.code, 404)
response = self.fetch("/connection_delegate")
self.assertEqual(response.body, b"OK")
response = self.fetch("/request_callable")
self.assertEqual(response.body, b"OK")
response = self.fetch("/404")
self.assertEqual(response.code, 404)
class WSGIContainerTestCase(AsyncHTTPTestCase):
def get_app(self):
wsgi_app = WSGIContainer(self.wsgi_app)
class Handler(RequestHandler):
def get(self, *args, **kwargs):
self.finish(self.reverse_url("tornado"))
return RuleRouter([
(PathMatches("/tornado.*"), Application([(r"/tornado/test", Handler, {}, "tornado")])),
(PathMatches("/wsgi"), wsgi_app),
])
def wsgi_app(self, environ, start_response):
start_response("200 OK", [])
return [b"WSGI"]
def test_wsgi_container(self):
response = self.fetch("/tornado/test")
self.assertEqual(response.body, b"/tornado/test")
response = self.fetch("/wsgi")
self.assertEqual(response.body, b"WSGI")
def test_delegate_not_found(self):
response = self.fetch("/404")
self.assertEqual(response.code, 404)

View file

@ -1,12 +1,13 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
from __future__ import absolute_import, division, print_function, with_statement
import gc
import io
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -25,10 +26,12 @@ TEST_MODULES = [
'tornado.util.doctests',
'tornado.test.asyncio_test',
'tornado.test.auth_test',
'tornado.test.autoreload_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
'tornado.test.gen_test',
'tornado.test.http1connection_test',
'tornado.test.httpclient_test',
'tornado.test.httpserver_test',
'tornado.test.httputil_test',
@ -42,6 +45,7 @@ TEST_MODULES = [
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.queues_test',
'tornado.test.routing_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
@ -52,6 +56,7 @@ TEST_MODULES = [
'tornado.test.util_test',
'tornado.test.web_test',
'tornado.test.websocket_test',
'tornado.test.windows_test',
'tornado.test.wsgi_test',
]
@ -60,16 +65,21 @@ def all():
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
class TornadoTextTestRunner(unittest.TextTestRunner):
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(textwrap.fill(
"Some tests were skipped because: %s" %
", ".join(sorted(skip_reasons))))
self.stream.write("\n")
return result
def test_runner_factory(stderr):
class TornadoTextTestRunner(unittest.TextTestRunner):
def __init__(self, *args, **kwargs):
super(TornadoTextTestRunner, self).__init__(*args, stream=stderr, **kwargs)
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(textwrap.fill(
"Some tests were skipped because: %s" %
", ".join(sorted(skip_reasons))))
self.stream.write("\n")
return result
return TornadoTextTestRunner
class LogCounter(logging.Filter):
@ -77,16 +87,31 @@ class LogCounter(logging.Filter):
def __init__(self, *args, **kwargs):
# Can't use super() because logging.Filter is an old-style class in py26
logging.Filter.__init__(self, *args, **kwargs)
self.warning_count = self.error_count = 0
self.info_count = self.warning_count = self.error_count = 0
def filter(self, record):
if record.levelno >= logging.ERROR:
self.error_count += 1
elif record.levelno >= logging.WARNING:
self.warning_count += 1
elif record.levelno >= logging.INFO:
self.info_count += 1
return True
class CountingStderr(io.IOBase):
def __init__(self, real):
self.real = real
self.byte_count = 0
def write(self, data):
self.byte_count += len(data)
return self.real.write(data)
def flush(self):
return self.real.flush()
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
@ -112,13 +137,18 @@ def main():
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
# unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
# instead of DeprecationWarnings.
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
message="Please use assert.* instead")
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
warnings.filterwarnings("ignore", category=BytesWarning,
module=r"twisted\..*")
if (3,) < sys.version_info < (3, 6):
# Prior to 3.6, async ResourceWarnings were rather noisy
# and even
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
# would generate a warning.
warnings.filterwarnings("ignore", category=ResourceWarning, # noqa: F821
module=r"asyncio\..*")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
@ -154,6 +184,12 @@ def main():
add_parse_callback(
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
# Certain errors (especially "unclosed resource" errors raised in
# destructors) go directly to stderr instead of logging. Count
# anything written by anything but the test runner as an error.
orig_stderr = sys.stderr
sys.stderr = CountingStderr(orig_stderr)
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
@ -163,17 +199,21 @@ def main():
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
kwargs['testRunner'] = test_runner_factory(orig_stderr)
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they logged
# any warnings or errors. We'd like to ban info logs too, but
# we can't count them cleanly due to interactions with LogTrapTestCase.
if log_counter.warning_count > 0 or log_counter.error_count > 0:
logging.error("logged %d warnings and %d errors",
log_counter.warning_count, log_counter.error_count)
# The tests should run clean; consider it a failure if they
# logged anything at info level or above.
if (log_counter.info_count > 0 or
log_counter.warning_count > 0 or
log_counter.error_count > 0 or
sys.stderr.byte_count > 0):
logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
log_counter.info_count, log_counter.warning_count,
log_counter.error_count, sys.stderr.byte_count)
sys.exit(1)
if __name__ == '__main__':
main()

View file

@ -1,4 +1,4 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import collections
from contextlib import closing
@ -11,25 +11,28 @@ import socket
import ssl
import sys
from tornado.escape import to_unicode
from tornado.escape import to_unicode, utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.iostream import UnsatisfiableReadError
from tornado.locks import Event
from tornado.log import gen_log
from tornado.concurrent import Future
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPStreamClosedError, HTTPTimeoutError
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler # noqa: E501
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest, skipBefore35, exec_test
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
from tornado.testing import (AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase,
ExpectLog, gen_test)
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, skipBefore35, exec_test
from tornado.web import RequestHandler, Application, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
force_instance=True)
client = SimpleAsyncHTTPClient(force_instance=True)
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
return client
@ -39,24 +42,33 @@ class TriggerHandler(RequestHandler):
self.queue = queue
self.wake_callback = wake_callback
@asynchronous
@gen.coroutine
def get(self):
logging.debug("queuing trigger")
self.queue.append(self.finish)
if self.get_argument("wake", "true") == "true":
self.wake_callback()
never_finish = Event()
yield never_finish.wait()
class HangHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
pass
never_finish = Event()
yield never_finish.wait()
class ContentLengthHandler(RequestHandler):
def get(self):
self.set_header("Content-Length", self.get_argument("value"))
self.write("ok")
self.stream = self.detach()
IOLoop.current().spawn_callback(self.write_response)
@gen.coroutine
def write_response(self):
yield self.stream.write(utf8("HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" %
self.get_argument("value")))
self.stream.close()
class HeadHandler(RequestHandler):
@ -72,10 +84,8 @@ class OptionsHandler(RequestHandler):
class NoContentHandler(RequestHandler):
def get(self):
if self.get_argument("error", None):
self.set_header("Content-Length", "5")
self.write("hello")
self.set_status(204)
self.finish()
class SeeOtherPostHandler(RequestHandler):
@ -99,13 +109,12 @@ class HostEchoHandler(RequestHandler):
class NoContentLengthHandler(RequestHandler):
@asynchronous
def get(self):
if self.request.version.startswith('HTTP/1'):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.detach()
stream = self.detach()
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
@ -151,16 +160,16 @@ class SimpleHTTPClientTestMixin(object):
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
SimpleAsyncHTTPClient(self.io_loop))
self.assertTrue(SimpleAsyncHTTPClient() is
SimpleAsyncHTTPClient())
# unless force_instance is used
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(self.io_loop,
force_instance=True))
self.assertTrue(SimpleAsyncHTTPClient() is not
SimpleAsyncHTTPClient(force_instance=True))
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(io_loop2))
client1 = self.io_loop.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
client2 = io_loop2.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
self.assertTrue(client1 is not client2)
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
@ -169,8 +178,8 @@ class SimpleHTTPClientTestMixin(object):
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
client.fetch(self.get_url("/trigger"),
lambda response, i=i: (seen.append(i), self.stop()))
client.fetch(self.get_url("/trigger")).add_done_callback(
lambda fut, i=i: (seen.append(i), self.stop()))
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
@ -189,12 +198,12 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
@gen_test
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/countdown/3'), self.stop,
max_redirects=3)
response = self.wait()
response = yield client.fetch(self.get_url('/countdown/3'),
max_redirects=3)
response.rethrow()
def test_gzip(self):
@ -237,55 +246,58 @@ class SimpleHTTPClientTestMixin(object):
# request is the original request, is a POST still
self.assertEqual("POST", response.request.method)
@skipOnTravis
@gen_test
def test_connect_timeout(self):
timeout = 0.1
class TimeoutResolver(Resolver):
def resolve(self, *args, **kwargs):
return Future() # never completes
with closing(self.create_client(resolver=TimeoutResolver())) as client:
with self.assertRaises(HTTPTimeoutError):
yield client.fetch(self.get_url('/hello'),
connect_timeout=timeout,
request_timeout=3600,
raise_error=True)
@skipOnTravis
def test_request_timeout(self):
timeout = 0.1
timeout_min, timeout_max = 0.099, 0.15
if os.name == 'nt':
timeout = 0.5
timeout_min, timeout_max = 0.4, 0.6
response = self.fetch('/trigger?wake=false', request_timeout=timeout)
self.assertEqual(response.code, 599)
self.assertTrue(timeout_min < response.request_time < timeout_max,
response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
with self.assertRaises(HTTPTimeoutError):
self.fetch('/trigger?wake=false', request_timeout=timeout, raise_error=True)
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@skipIfNoIPv6
def test_ipv6(self):
try:
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
except socket.gaierror as e:
if e.args[0] == socket.EAI_ADDRFAMILY:
# python supports ipv6, but it's not configured on the network
# interface, so skip this test.
return
raise
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
# ipv6 is currently enabled by default but can be disabled
self.http_client.fetch(url, self.stop, allow_ipv6=False)
response = self.wait()
self.assertEqual(response.code, 599)
with self.assertRaises(Exception):
self.fetch(url, allow_ipv6=False, raise_error=True)
self.http_client.fetch(url, self.stop)
response = self.wait()
response = self.fetch(url)
self.assertEqual(response.body, b"Hello world!")
def xtest_multiple_content_length_accepted(self):
def test_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,4")
self.assertEqual(response.code, 599)
response = self.fetch("/content_length?value=2,%202,3")
self.assertEqual(response.code, 599)
with ExpectLog(gen_log, ".*Multiple unequal Content-Lengths"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,4", raise_error=True)
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,%202,3", raise_error=True)
def test_head_request(self):
response = self.fetch("/head", method="HEAD")
@ -303,63 +315,52 @@ class SimpleHTTPClientTestMixin(object):
def test_no_content(self):
response = self.fetch("/no_content")
self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway.
# 204 status shouldn't have a content-length
#
# A test without a content-length header is included below
# Tests with a content-length header are included below
# in HTTP204NoContentTestCase.
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
with ExpectLog(gen_log, "Malformed HTTP message"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
self.assertNotIn("Content-Length", response.headers)
def test_host_header(self):
host_re = re.compile(b"^localhost:[0-9]+$")
host_re = re.compile(b"^127.0.0.1:[0-9]+$")
response = self.fetch("/host_echo")
self.assertTrue(host_re.match(response.body))
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
response = self.fetch(url)
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
with self.assertRaises(socket.error) as cm:
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(response.error)
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
self.assertTrue(contains_errno, response.error)
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception)
self.assertTrue(contains_errno, cm.exception)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(response.error),
response.error)
self.assertTrue(expected_message in str(cm.exception),
cm.exception)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/trigger'), self.stop,
request_timeout=10)
# Wait for the trigger request to block, not complete.
fut1 = client.fetch(self.get_url('/trigger'), request_timeout=10)
self.wait()
client.fetch(self.get_url('/hello'), self.stop,
connect_timeout=0.1)
response = self.wait()
with self.assertRaises(HTTPTimeoutError) as cm:
self.io_loop.run_sync(lambda: client.fetch(
self.get_url('/hello'), connect_timeout=0.1, raise_error=True))
self.assertEqual(response.code, 599)
self.assertTrue(response.request_time < 1, response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
self.assertEqual(str(cm.exception), "Timeout in request queue")
self.triggers.popleft()()
self.wait()
self.io_loop.run_sync(lambda: fut1)
def test_no_content_length(self):
response = self.fetch("/no_content_length")
@ -375,7 +376,7 @@ class SimpleHTTPClientTestMixin(object):
@gen.coroutine
def async_body_producer(self, write):
yield write(b'1234')
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
yield write(b'5678')
def test_sync_body_producer_chunked(self):
@ -409,7 +410,8 @@ class SimpleHTTPClientTestMixin(object):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
await gen.Task(IOLoop.current().add_callback)
import asyncio
await asyncio.sleep(0)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
@ -422,7 +424,8 @@ class SimpleHTTPClientTestMixin(object):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
await gen.Task(IOLoop.current().add_callback)
import asyncio
await asyncio.sleep(0)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
@ -470,8 +473,7 @@ class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
**kwargs)
return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
@ -480,7 +482,7 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
return SimpleAsyncHTTPClient(force_instance=True,
defaults=dict(validate_cert=False),
**kwargs)
@ -488,8 +490,6 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context(self):
resp = self.fetch("/hello",
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
@ -498,27 +498,25 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
def test_ssl_options_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
required=False):
resp = self.fetch(
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
self.assertRaises(ssl.SSLError, resp.rethrow)
with self.assertRaises(ssl.SSLError):
self.fetch(
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
raise_error=True)
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
resp = self.fetch("/hello", ssl_options=ctx)
self.assertRaises(ssl.SSLError, resp.rethrow)
with self.assertRaises(ssl.SSLError):
self.fetch("/hello", ssl_options=ctx, raise_error=True)
def test_error_logging(self):
# No stack traces are logged for SSL errors (in this case,
# failure to validate the testing self-signed cert).
# The SSLError is exposed through ssl.SSLError.
with ExpectLog(gen_log, '.*') as expect_log:
response = self.fetch("/", validate_cert=True)
self.assertEqual(response.code, 599)
self.assertIsInstance(response.error, ssl.SSLError)
with self.assertRaises(ssl.SSLError):
self.fetch("/", validate_cert=True, raise_error=True)
self.assertFalse(expect_log.logged_stack)
@ -533,24 +531,22 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 10)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=11, force_instance=True)) as client:
max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11)
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 12)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=13, force_instance=True)) as client:
max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=14, force_instance=True)) as client:
max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14)
@ -563,14 +559,15 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
request.connection.finish()
return
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
self.respond_200)
fut = self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n")
fut.add_done_callback(self.respond_200)
def respond_200(self):
self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA",
self.request.connection.stream.close)
def respond_200(self, fut):
fut.result()
fut = self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA")
fut.add_done_callback(lambda f: self.request.connection.stream.close())
def get_app(self):
# Not a full Application, but works as an HTTPServer callback
@ -592,16 +589,20 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
HTTPHeaders())
request.connection.finish()
return
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
# no content length and does not close the connection.
# (which would otherwise mean read-until-close). We simulate here a
# server that sends no content length and does not close the connection.
#
# Tests of a 204 response with a Content-Length header are included
# Tests of a 204 response with no Content-Length header are included
# in SimpleHTTPClientTestMixin.
stream = request.connection.detach()
stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
stream.write(b"HTTP/1.1 204 No content\r\n")
if request.arguments.get("error", [False])[-1]:
stream.write(b"Content-Length: 5\r\n")
else:
stream.write(b"Content-Length: 0\r\n")
stream.write(b"\r\n")
stream.close()
def get_app(self):
@ -614,12 +615,21 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
def test_204_invalid_content_length(self):
# 204 status with non-zero content length is malformed
with ExpectLog(gen_log, ".*Response with code 204 should not have body"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/?error=1", raise_error=True)
if not self.http1:
self.skipTest("requires HTTP/1.x")
if self.http_client.configured_class != SimpleAsyncHTTPClient:
self.skipTest("curl client accepts invalid headers")
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
hostname_mapping={
'www.example.com': '127.0.0.1',
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
@ -629,37 +639,35 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
return Application([url("/hello", HelloWorldHandler), ])
def test_hostname_mapping(self):
self.http_client.fetch(
'http://www.example.com:%d/hello' % self.get_http_port(), self.stop)
response = self.wait()
response = self.fetch(
'http://www.example.com:%d/hello' % self.get_http_port())
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
def test_port_mapping(self):
self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
response = self.wait()
response = self.fetch('http://foo.example.com:8000/hello')
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def setUp(self):
# Dummy Resolver subclass that never invokes its callback.
# Dummy Resolver subclass that never finishes.
class BadResolver(Resolver):
@gen.coroutine
def resolve(self, *args, **kwargs):
pass
yield Event().wait()
super(ResolveTimeoutTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
resolver=BadResolver())
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_resolve_timeout(self):
response = self.fetch('/hello', connect_timeout=0.1)
self.assertEqual(response.code, 599)
with self.assertRaises(HTTPTimeoutError):
self.fetch('/hello', connect_timeout=0.1, raise_error=True)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
@ -678,7 +686,7 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
('/large', LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
return SimpleAsyncHTTPClient(max_header_size=1024)
def test_small_headers(self):
response = self.fetch('/small')
@ -687,8 +695,8 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
with self.assertRaises(UnsatisfiableReadError):
self.fetch('/large', raise_error=True)
class MaxBodySizeTest(AsyncHTTPTestCase):
@ -705,7 +713,7 @@ class MaxBodySizeTest(AsyncHTTPTestCase):
('/large', LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 64)
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
def test_small_body(self):
response = self.fetch('/small')
@ -714,8 +722,8 @@ class MaxBodySizeTest(AsyncHTTPTestCase):
def test_large_body(self):
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
with self.assertRaises(HTTPStreamClosedError):
self.fetch('/large', raise_error=True)
class MaxBufferSizeTest(AsyncHTTPTestCase):
@ -729,7 +737,7 @@ class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_http_client(self):
# 100KB body with 64KB buffer
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 100, max_buffer_size=1024 * 64)
return SimpleAsyncHTTPClient(max_body_size=1024 * 100, max_buffer_size=1024 * 64)
def test_large_body(self):
response = self.fetch('/large')
@ -754,6 +762,6 @@ class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
def test_chunked_with_content_length(self):
# Make sure the invalid headers are detected
with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
"with both Transfer-Encoding and Content-Length")):
response = self.fetch('/chunkwithcl')
self.assertEqual(response.code, 599)
"with both Transfer-Encoding and Content-Length")):
with self.assertRaises(HTTPStreamClosedError):
self.fetch('/chunkwithcl', raise_error=True)

View file

@ -1,35 +1,36 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
ExceptionStackContext, run_with_stack_context, _state)
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.test.util import unittest, ignore_deprecation
from tornado.web import asynchronous, Application, RequestHandler
import contextlib
import functools
import logging
import warnings
class TestRequestHandler(RequestHandler):
def __init__(self, app, request, io_loop):
def __init__(self, app, request):
super(TestRequestHandler, self).__init__(app, request)
self.io_loop = io_loop
@asynchronous
def get(self):
logging.debug('in get()')
# call self.part2 without a self.async_callback wrapper. Its
# exception should still get thrown
self.io_loop.add_callback(self.part2)
with ignore_deprecation():
@asynchronous
def get(self):
logging.debug('in get()')
# call self.part2 without a self.async_callback wrapper. Its
# exception should still get thrown
IOLoop.current().add_callback(self.part2)
def part2(self):
logging.debug('in part2()')
# Go through a third layer to make sure that contexts once restored
# are again passed on to future callbacks
self.io_loop.add_callback(self.part3)
IOLoop.current().add_callback(self.part3)
def part3(self):
logging.debug('in part3()')
@ -44,13 +45,13 @@ class TestRequestHandler(RequestHandler):
class HTTPStackContextTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', TestRequestHandler,
dict(io_loop=self.io_loop))])
return Application([('/', TestRequestHandler)])
def test_stack_context(self):
with ExpectLog(app_log, "Uncaught exception GET /"):
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
with ignore_deprecation():
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
self.assertEqual(self.response.code, 500)
self.assertTrue(b'got expected exception' in self.response.body)
@ -63,6 +64,13 @@ class StackContextTest(AsyncTestCase):
def setUp(self):
super(StackContextTest, self).setUp()
self.active_contexts = []
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
def tearDown(self):
self.warning_catcher.__exit__(None, None, None)
super(StackContextTest, self).tearDown()
@contextlib.contextmanager
def context(self, name):
@ -284,5 +292,6 @@ class StackContextTest(AsyncTestCase):
f1)
self.assertEqual(self.active_contexts, [])
if __name__ == '__main__':
unittest.main()

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
@ -14,7 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from contextlib import closing
import os
@ -22,10 +21,12 @@ import socket
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.queues import Queue
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
from tornado.gen import TimeoutError
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
@ -36,12 +37,14 @@ class TestTCPServer(TCPServer):
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
self.queue = Queue()
sockets = bind_sockets(None, 'localhost', family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
self.queue.put(stream)
def stop(self):
super(TestTCPServer, self).stop()
@ -74,19 +77,21 @@ class TCPClientTest(AsyncTestCase):
def skipIfLocalhostV4(self):
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
Resolver().resolve('localhost', 80, callback=self.stop)
addrinfo = self.wait()
addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve('localhost', 80))
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host):
def do_test_connect(self, family, host, source_ip=None, source_port=None):
port = self.start_server(family)
stream = yield self.client.connect(host, port)
stream = yield self.client.connect(host, port,
source_ip=source_ip,
source_port=source_port)
server_stream = yield self.server.queue.get()
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
data = yield server_stream.read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
@ -125,6 +130,44 @@ class TCPClientTest(AsyncTestCase):
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
def test_source_ip_fail(self):
    """Connecting with the non-local source IP 8.8.8.8 must raise socket.error."""
    with self.assertRaises(socket.error):
        self.do_test_connect(socket.AF_INET, '127.0.0.1',
                             source_ip='8.8.8.8')
def test_source_ip_success(self):
    """Connecting with the local source IP 127.0.0.1 must succeed."""
    self.do_test_connect(socket.AF_INET, '127.0.0.1',
                         source_ip='127.0.0.1')
@skipIfNonUnix
def test_source_port_fail(self):
    """Binding privileged source port 1 (as non-root) must raise socket.error."""
    with self.assertRaises(socket.error):
        self.do_test_connect(socket.AF_INET, '127.0.0.1',
                             source_port=1)
@gen_test
def test_connect_timeout(self):
    """connect() must raise TimeoutError when name resolution never finishes."""
    class HangingResolver(Resolver):
        def resolve(self, *args, **kwargs):
            # Return a future that is never resolved so only the
            # connect timeout can complete the operation.
            return Future()

    with self.assertRaises(TimeoutError):
        yield TCPClient(resolver=HangingResolver()).connect(
            '1.2.3.4', 12345, timeout=0.05)
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
@ -169,9 +212,11 @@ class ConnectorTest(AsyncTestCase):
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
stream = ConnectorTest.FakeStream()
self.streams[addr] = stream
future = Future()
self.connect_futures[(af, addr)] = future
return future
return stream, future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
@ -179,15 +224,22 @@ class ConnectorTest(AsyncTestCase):
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
self.streams[addr] = ConnectorTest.FakeStream()
future.set_result(self.streams[addr])
else:
self.streams.pop(addr)
future.set_exception(IOError())
# Run the loop to allow callbacks to be run.
self.io_loop.add_callback(self.stop)
self.wait()
def assert_connector_streams_closed(self, conn):
    # Every stream the connector opened during its attempts must be closed.
    for attempted_stream in conn.streams:
        self.assertTrue(attempted_stream.closed)
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
conn = _Connector(addrinfo, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600)
future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
return conn, future
def test_immediate_success(self):
@ -278,3 +330,101 @@ class ConnectorTest(AsyncTestCase):
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_timeout_after_connect_timeout(self):
    """After a connect timeout, on_timeout() must not start new attempts."""
    conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
    self.assert_pending((AF1, 'a'))
    conn.on_connect_timeout()
    # the connector will close all streams on connect timeout, we
    # should explicitly pop the connect_future.
    self.connect_futures.pop((AF1, 'a'))
    self.assertTrue(self.streams.pop('a').closed)
    conn.on_timeout()
    # if the future is set with TimeoutError, we will not iterate next
    # possible address.
    self.assert_pending()
    self.assertEqual(len(conn.streams), 1)
    self.assert_connector_streams_closed(conn)
    self.assertRaises(TimeoutError, future.result)
def test_one_family_success_before_connect_timeout(self):
    """A connection that succeeded first survives a later connect timeout."""
    conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
    self.assert_pending((AF1, 'a'))
    self.resolve_connect(AF1, 'a', True)
    conn.on_connect_timeout()
    self.assert_pending()
    # the winning stream must remain open
    self.assertEqual(self.streams['a'].closed, False)
    # success stream will be pop
    self.assertEqual(len(conn.streams), 0)
    # streams in connector should be closed after connect timeout
    self.assert_connector_streams_closed(conn)
    self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_one_family_second_try_after_connect_timeout(self):
    """A connect timeout during the second attempt abandons it and fails."""
    conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
    self.assert_pending((AF1, 'a'))
    self.resolve_connect(AF1, 'a', False)
    self.assert_pending((AF1, 'b'))
    conn.on_connect_timeout()
    # the still-pending second attempt is dropped and its stream closed
    self.connect_futures.pop((AF1, 'b'))
    self.assertTrue(self.streams.pop('b').closed)
    self.assert_pending()
    self.assertEqual(len(conn.streams), 2)
    self.assert_connector_streams_closed(conn)
    self.assertRaises(TimeoutError, future.result)
def test_one_family_second_try_failure_before_connect_timeout(self):
    """If every attempt already failed, the IOError wins over the timeout."""
    conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
    self.assert_pending((AF1, 'a'))
    self.resolve_connect(AF1, 'a', False)
    self.assert_pending((AF1, 'b'))
    self.resolve_connect(AF1, 'b', False)
    conn.on_connect_timeout()
    self.assert_pending()
    self.assertEqual(len(conn.streams), 2)
    self.assert_connector_streams_closed(conn)
    # the future was resolved with IOError before the timeout fired
    self.assertRaises(IOError, future.result)
def test_two_family_timeout_before_connect_timeout(self):
    """A connect timeout cancels pending attempts in both address families."""
    conn, future = self.start_connect(self.addrinfo)
    self.assert_pending((AF1, 'a'))
    # the per-family timeout starts the secondary-family (AF2) attempt
    conn.on_timeout()
    self.assert_pending((AF1, 'a'), (AF2, 'c'))
    conn.on_connect_timeout()
    self.connect_futures.pop((AF1, 'a'))
    self.assertTrue(self.streams.pop('a').closed)
    self.connect_futures.pop((AF2, 'c'))
    self.assertTrue(self.streams.pop('c').closed)
    self.assert_pending()
    self.assertEqual(len(conn.streams), 2)
    self.assert_connector_streams_closed(conn)
    self.assertRaises(TimeoutError, future.result)
def test_two_family_success_after_timeout(self):
    """Success on the primary attempt closes the secondary-family attempt."""
    conn, future = self.start_connect(self.addrinfo)
    self.assert_pending((AF1, 'a'))
    conn.on_timeout()
    self.assert_pending((AF1, 'a'), (AF2, 'c'))
    self.resolve_connect(AF1, 'a', True)
    # if one of streams succeed, connector will close all other streams
    self.connect_futures.pop((AF2, 'c'))
    self.assertTrue(self.streams.pop('c').closed)
    self.assert_pending()
    self.assertEqual(len(conn.streams), 1)
    self.assert_connector_streams_closed(conn)
    self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_two_family_timeout_after_connect_timeout(self):
    """After a connect timeout, the family timeout must not start AF2."""
    conn, future = self.start_connect(self.addrinfo)
    self.assert_pending((AF1, 'a'))
    conn.on_connect_timeout()
    self.connect_futures.pop((AF1, 'a'))
    self.assertTrue(self.streams.pop('a').closed)
    self.assert_pending()
    conn.on_timeout()
    # if the future is set with TimeoutError, connector will not
    # trigger secondary address.
    self.assert_pending()
    self.assertEqual(len(conn.streams), 1)
    self.assert_connector_streams_closed(conn)
    self.assertRaises(TimeoutError, future.result)

View file

@ -1,11 +1,17 @@
from __future__ import absolute_import, division, print_function, with_statement
import socket
from __future__ import absolute_import, division, print_function
import socket
import subprocess
import sys
import textwrap
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.test.util import skipBefore35, skipIfNonUnix, exec_test, unittest
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
@ -17,7 +23,7 @@ class TCPServerTest(AsyncTestCase):
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield gen.moment
yield stream.read_bytes(len(b'hello'))
stream.close()
1 / 0
@ -30,6 +36,7 @@ class TCPServerTest(AsyncTestCase):
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
yield client.connect(('localhost', port))
yield client.write(b'hello')
yield client.read_until_close()
yield gen.moment
finally:
@ -37,3 +44,150 @@ class TCPServerTest(AsyncTestCase):
server.stop()
if client is not None:
client.close()
@skipBefore35
@gen_test
def test_handle_stream_native_coroutine(self):
    # handle_stream may be a native coroutine.
    # exec_test compiles the snippet at runtime so this file still
    # parses on interpreters without async/await syntax.
    namespace = exec_test(globals(), locals(), """
    class TestServer(TCPServer):
        async def handle_stream(self, stream, address):
            stream.write(b'data')
            stream.close()
    """)
    sock, port = bind_unused_port()
    server = namespace['TestServer']()
    server.add_socket(sock)
    client = IOStream(socket.socket())
    yield client.connect(('localhost', port))
    result = yield client.read_until_close()
    self.assertEqual(result, b'data')
    server.stop()
    client.close()
def test_stop_twice(self):
    """Calling stop() on an already-stopped server must be a no-op."""
    listener, _ = bind_unused_port()
    server = TCPServer()
    server.add_socket(listener)
    server.stop()
    server.stop()
@gen_test
def test_stop_in_callback(self):
    # Issue #2069: calling server.stop() in a loop callback should not
    # raise EBADF when the loop handles other server connection
    # requests in the same loop iteration
    class TestServer(TCPServer):
        @gen.coroutine
        def handle_stream(self, stream, address):
            # Stop the server from within the first accepted connection,
            # while other accepts may still be queued in this iteration.
            server.stop()
            yield stream.read_until_close()

    sock, port = bind_unused_port()
    server = TestServer()
    server.add_socket(sock)
    server_addr = ('localhost', port)
    N = 40
    clients = [IOStream(socket.socket()) for i in range(N)]
    connected_clients = []

    @gen.coroutine
    def connect(c):
        try:
            yield c.connect(server_addr)
        except EnvironmentError:
            pass
        else:
            connected_clients.append(c)

    yield [connect(c) for c in clients]

    self.assertGreater(len(connected_clients), 0,
                       "all clients failed connecting")
    try:
        if len(connected_clients) == N:
            # Ideally we'd make the test deterministic, but we're testing
            # for a race condition in combination with the system's TCP stack...
            self.skipTest("at least one client should fail connecting "
                          "for the test to be meaningful")
    finally:
        for c in connected_clients:
            c.close()

    # Here tearDown() would re-raise the EBADF encountered in the IO loop
@skipIfNonUnix
class TestMultiprocess(unittest.TestCase):
    # These tests verify that the two multiprocess examples from the
    # TCPServer docs work. Both tests start a server with three worker
    # processes, each of which prints its task id to stdout (a single
    # byte, so we don't have to worry about atomicity of the shared
    # stdout stream) and then exits.
    def run_subproc(self, code):
        """Run ``code`` in a python subprocess and return its decoded stdout.

        Raises RuntimeError when the subprocess exits with a non-zero
        status, including the captured stdout for diagnosis.
        """
        proc = subprocess.Popen(sys.executable,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        proc.stdin.write(utf8(code))
        proc.stdin.close()
        proc.wait()
        stdout = proc.stdout.read()
        proc.stdout.close()
        if proc.returncode != 0:
            raise RuntimeError("Process returned %d. stdout=%r" % (
                proc.returncode, stdout))
        return to_unicode(stdout)

    def test_single(self):
        # As a sanity check, run the single-process version through this test
        # harness too.
        code = textwrap.dedent("""
            from __future__ import print_function
            from tornado.ioloop import IOLoop
            from tornado.tcpserver import TCPServer

            server = TCPServer()
            server.listen(0, address='127.0.0.1')
            IOLoop.current().run_sync(lambda: None)
            print('012', end='')
        """)
        out = self.run_subproc(code)
        self.assertEqual(''.join(sorted(out)), "012")

    def test_simple(self):
        # "Simple" multiprocess pattern: bind, then start(3) forks workers.
        code = textwrap.dedent("""
            from __future__ import print_function
            from tornado.ioloop import IOLoop
            from tornado.process import task_id
            from tornado.tcpserver import TCPServer

            server = TCPServer()
            server.bind(0, address='127.0.0.1')
            server.start(3)
            IOLoop.current().run_sync(lambda: None)
            print(task_id(), end='')
        """)
        out = self.run_subproc(code)
        self.assertEqual(''.join(sorted(out)), "012")

    def test_advanced(self):
        # "Advanced" pattern: bind sockets first, fork explicitly, then
        # hand the shared sockets to a server in each worker.
        code = textwrap.dedent("""
            from __future__ import print_function
            from tornado.ioloop import IOLoop
            from tornado.netutil import bind_sockets
            from tornado.process import fork_processes, task_id
            from tornado.ioloop import IOLoop
            from tornado.tcpserver import TCPServer

            sockets = bind_sockets(0, address='127.0.0.1')
            fork_processes(3)
            server = TCPServer()
            server.add_sockets(sockets)
            IOLoop.current().run_sync(lambda: None)
            print(task_id(), end='')
        """)
        out = self.run_subproc(code)
        self.assertEqual(''.join(sorted(out)), "012")

View file

@ -1,4 +1,4 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import os
import sys
@ -7,7 +7,7 @@ import traceback
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest
from tornado.util import u, ObjectDict, unicode_type
from tornado.util import ObjectDict, unicode_type
class TemplateTest(unittest.TestCase):
@ -67,12 +67,13 @@ class TemplateTest(unittest.TestCase):
self.assertRaises(ParseError, lambda: Template("{%"))
self.assertEqual(Template("{{!").generate(), b"{{")
self.assertEqual(Template("{%!").generate(), b"{%")
self.assertEqual(Template("{#!").generate(), b"{#")
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
b"expr {{jquery expr}}")
def test_unicode_template(self):
template = Template(utf8(u("\u00e9")))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
template = Template(utf8(u"\u00e9"))
self.assertEqual(template.generate(), utf8(u"\u00e9"))
def test_unicode_literal_expression(self):
# Unicode literals should be usable in templates. Note that this
@ -82,10 +83,10 @@ class TemplateTest(unittest.TestCase):
if str is unicode_type:
# python 3 needs a different version of this test since
# 2to3 doesn't run on template internals
template = Template(utf8(u('{{ "\u00e9" }}')))
template = Template(utf8(u'{{ "\u00e9" }}'))
else:
template = Template(utf8(u('{{ u"\u00e9" }}')))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
template = Template(utf8(u'{{ u"\u00e9" }}'))
self.assertEqual(template.generate(), utf8(u"\u00e9"))
def test_custom_namespace(self):
loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
@ -100,14 +101,14 @@ class TemplateTest(unittest.TestCase):
def test_unicode_apply(self):
def upper(s):
return to_unicode(s).upper()
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
def test_bytes_apply(self):
def upper(s):
return utf8(to_unicode(s).upper())
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
def test_if(self):
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
@ -174,8 +175,8 @@ try{% set y = 1/x %}
self.assertEqual(template.generate(), '0')
def test_non_ascii_name(self):
loader = DictLoader({u("t\u00e9st.html"): "hello"})
self.assertEqual(loader.load(u("t\u00e9st.html")).generate(), b"hello")
loader = DictLoader({u"t\u00e9st.html": "hello"})
self.assertEqual(loader.load(u"t\u00e9st.html").generate(), b"hello")
class StackTraceTest(unittest.TestCase):
@ -202,10 +203,15 @@ three{%end%}
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_module(self):
loader = None
def load_generate(path, **kwargs):
return loader.load(path).generate(**kwargs)
loader = DictLoader({
"base.html": "{% module Template('sub.html') %}",
"sub.html": "{{1/0}}",
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
}, namespace={"_tt_modules": ObjectDict(Template=load_generate)})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
@ -280,6 +286,11 @@ class ParseErrorDetailTest(unittest.TestCase):
self.assertEqual("foo.html", cm.exception.filename)
self.assertEqual(3, cm.exception.lineno)
def test_custom_parse_error(self):
# Make sure that ParseErrors remain compatible with their
# pre-4.3 signature.
self.assertEqual("asdf at None:0", str(ParseError("asdf")))
class AutoEscapeTest(unittest.TestCase):
def setUp(self):
@ -482,4 +493,4 @@ class TemplateLoaderTest(unittest.TestCase):
def test_utf8_in_file(self):
tmpl = self.loader.load("utf8.html")
result = tmpl.generate()
self.assertEqual(to_unicode(result).strip(), u("H\u00e9llo"))
self.assertEqual(to_unicode(result).strip(), u"H\u00e9llo")

View file

@ -1,16 +1,22 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado import gen, ioloop
from tornado.log import app_log
from tornado.testing import AsyncTestCase, gen_test, ExpectLog
from tornado.test.util import unittest, skipBefore35, exec_test
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPTimeoutError
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.web import Application
import contextlib
import os
import platform
import traceback
import warnings
try:
import asyncio
except ImportError:
asyncio = None
@contextlib.contextmanager
def set_environ(name, value):
@ -28,12 +34,13 @@ def set_environ(name, value):
class AsyncTestCaseTest(AsyncTestCase):
def test_exception_in_callback(self):
self.io_loop.add_callback(lambda: 1 / 0)
try:
self.wait()
self.fail("did not get expected exception")
except ZeroDivisionError:
pass
with ignore_deprecation():
self.io_loop.add_callback(lambda: 1 / 0)
try:
self.wait()
self.fail("did not get expected exception")
except ZeroDivisionError:
pass
def test_wait_timeout(self):
time = self.io_loop.time
@ -64,15 +71,56 @@ class AsyncTestCaseTest(AsyncTestCase):
self.wait(timeout=0.15)
def test_multiple_errors(self):
def fail(message):
raise Exception(message)
self.io_loop.add_callback(lambda: fail("error one"))
self.io_loop.add_callback(lambda: fail("error two"))
# The first error gets raised; the second gets logged.
with ExpectLog(app_log, "multiple unhandled exceptions"):
with self.assertRaises(Exception) as cm:
self.wait()
self.assertEqual(str(cm.exception), "error one")
with ignore_deprecation():
def fail(message):
raise Exception(message)
self.io_loop.add_callback(lambda: fail("error one"))
self.io_loop.add_callback(lambda: fail("error two"))
# The first error gets raised; the second gets logged.
with ExpectLog(app_log, "multiple unhandled exceptions"):
with self.assertRaises(Exception) as cm:
self.wait()
self.assertEqual(str(cm.exception), "error one")
class AsyncHTTPTestCaseTest(AsyncHTTPTestCase):
    """Exercises AsyncHTTPTestCase itself, particularly its URL handling."""

    @classmethod
    def setUpClass(cls):
        super(AsyncHTTPTestCaseTest, cls).setUpClass()
        # Reserve a port that nothing listens on, so requests aimed at it
        # cannot accidentally reach a real local web server.
        cls.external_sock, cls.external_port = bind_unused_port()

    @classmethod
    def tearDownClass(cls):
        cls.external_sock.close()
        super(AsyncHTTPTestCaseTest, cls).tearDownClass()

    def get_app(self):
        return Application()

    def test_fetch_segment(self):
        segment = '/path'
        resp = self.fetch(segment)
        self.assertEqual(resp.request.url, self.get_url(segment))

    @gen_test
    def test_fetch_full_http_url(self):
        url = 'http://localhost:%d/path' % self.external_port
        with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
            with self.assertRaises(HTTPTimeoutError) as cm:
                yield client.fetch(url, request_timeout=0.1, raise_error=True)
            self.assertEqual(cm.exception.response.request.url, url)

    @gen_test
    def test_fetch_full_https_url(self):
        url = 'https://localhost:%d/path' % self.external_port
        with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
            with self.assertRaises(HTTPTimeoutError) as cm:
                yield client.fetch(url, request_timeout=0.1, raise_error=True)
            self.assertEqual(cm.exception.response.request.url, url)
class AsyncTestCaseWrapperTest(unittest.TestCase):
@ -87,6 +135,8 @@ class AsyncTestCaseWrapperTest(unittest.TestCase):
self.assertIn("should be decorated", result.errors[0][1])
@skipBefore35
@unittest.skipIf(platform.python_implementation() == 'PyPy',
'pypy destructor warnings cannot be silenced')
def test_undecorated_coroutine(self):
namespace = exec_test(globals(), locals(), """
class Test(AsyncTestCase):
@ -172,14 +222,14 @@ class GenTest(AsyncTestCase):
@gen_test
def test_async(self):
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
self.finished = True
def test_timeout(self):
# Set a short timeout and exceed it.
@gen_test(timeout=0.1)
def test(self):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
yield gen.sleep(1)
# This can't use assertRaises because we need to inspect the
# exc_info triple (and not just the exception object)
@ -190,7 +240,7 @@ class GenTest(AsyncTestCase):
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
self.assertIn(
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
"gen.sleep(1)",
traceback.format_exc())
self.finished = True
@ -199,8 +249,7 @@ class GenTest(AsyncTestCase):
# A test that does not exceed its timeout should succeed.
@gen_test(timeout=1)
def test(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.1)
yield gen.sleep(0.1)
test(self)
self.finished = True
@ -208,8 +257,7 @@ class GenTest(AsyncTestCase):
def test_timeout_environment_variable(self):
@gen_test(timeout=0.5)
def test_long_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.25)
yield gen.sleep(0.25)
# Uses provided timeout of 0.5 seconds, doesn't time out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
@ -220,8 +268,7 @@ class GenTest(AsyncTestCase):
def test_no_timeout_environment_variable(self):
@gen_test(timeout=0.01)
def test_short_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 1)
yield gen.sleep(1)
# Uses environment-variable timeout of 0.1, times out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
@ -234,7 +281,7 @@ class GenTest(AsyncTestCase):
@gen_test
def test_with_args(self, *args):
self.assertEqual(args, ('test',))
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
test_with_args(self, 'test')
self.finished = True
@ -243,7 +290,7 @@ class GenTest(AsyncTestCase):
@gen_test
def test_with_kwargs(self, **kwargs):
self.assertDictEqual(kwargs, {'test': 'test'})
yield gen.Task(self.io_loop.add_callback)
yield gen.moment
test_with_kwargs(self, test='test')
self.finished = True
@ -274,5 +321,30 @@ class GenTest(AsyncTestCase):
self.finished = True
@unittest.skipIf(asyncio is None, "asyncio module not present")
class GetNewIOLoopTest(AsyncTestCase):
    def get_new_ioloop(self):
        # Use the current loop instead of creating a new one here.
        return ioloop.IOLoop.current()

    def setUp(self):
        # This simulates the effect of an asyncio test harness like
        # pytest-asyncio.
        self.orig_loop = asyncio.get_event_loop()
        self.new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.new_loop)
        super(GetNewIOLoopTest, self).setUp()

    def tearDown(self):
        super(GetNewIOLoopTest, self).tearDown()
        # AsyncTestCase must not affect the existing asyncio loop.
        self.assertFalse(asyncio.get_event_loop().is_closed())
        # Restore and dispose of the loop we installed in setUp.
        asyncio.set_event_loop(self.orig_loop)
        self.new_loop.close()

    def test_loop(self):
        # The overridden get_new_ioloop should have wrapped the loop
        # installed in setUp, not created a fresh one.
        self.assertIs(self.io_loop.asyncio_loop, self.new_loop)
if __name__ == '__main__':
unittest.main()

View file

@ -17,7 +17,7 @@
Unittest for the twisted-style reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import logging
import os
@ -28,14 +28,25 @@ import tempfile
import threading
import warnings
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop, PollIOLoop
from tornado.platform.auto import set_close_exec
from tornado.testing import bind_unused_port
from tornado.test.util import unittest
from tornado.util import import_object, PY3
from tornado.web import RequestHandler, Application
try:
import fcntl
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue # type: ignore
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor # type: ignore
from twisted.internet.protocol import Protocol # type: ignore
from twisted.python import log # type: ignore
from tornado.platform.twisted import TornadoReactor, TwistedIOLoop
from zope.interface import implementer
from zope.interface import implementer # type: ignore
have_twisted = True
except ImportError:
have_twisted = False
@ -43,38 +54,29 @@ except ImportError:
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
# so test for it separately.
try:
from twisted.web.client import Agent, readBody
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.client import Agent, readBody # type: ignore
from twisted.web.resource import Resource # type: ignore
from twisted.web.server import Site # type: ignore
# As of Twisted 15.0.0, twisted.web is present but fails our
# tests due to internal str/bytes errors.
have_twisted_web = sys.version_info < (3,)
except ImportError:
have_twisted_web = False
try:
import thread # py2
except ImportError:
import _thread as thread # py3
if PY3:
import _thread as thread
else:
import thread
ResourceWarning = None
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.platform.select import SelectIOLoop
from tornado.testing import bind_unused_port
from tornado.test.util import unittest
from tornado.util import import_object
from tornado.web import RequestHandler, Application
try:
import asyncio
except ImportError:
asyncio = None
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
skipIfPy26 = unittest.skipIf(sys.version_info < (2, 7),
"twisted incompatible with singledispatch in py26")
def save_signal_handlers():
saved = {}
@ -97,8 +99,10 @@ def restore_signal_handlers(saved):
class ReactorTestCase(unittest.TestCase):
def setUp(self):
self._saved_signals = save_signal_handlers()
self._io_loop = IOLoop()
self._reactor = TornadoReactor(self._io_loop)
IOLoop.clear_current()
self._io_loop = IOLoop(make_current=True)
self._reactor = TornadoReactor()
IOLoop.clear_current()
def tearDown(self):
self._io_loop.close(all_fds=True)
@ -219,53 +223,51 @@ class ReactorCallInThread(ReactorTestCase):
self._reactor.run()
class Reader(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Reader"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def readConnectionLost(self, reason):
self.close()
def connectionLost(self, reason):
self.close()
def doRead(self):
self._callback(self._fd)
if have_twisted:
Reader = implementer(IReadDescriptor)(Reader)
@implementer(IReadDescriptor)
class Reader(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Reader"
class Writer(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def close(self):
self._fd.close()
def logPrefix(self):
return "Writer"
def fileno(self):
return self._fd.fileno()
def close(self):
self._fd.close()
def readConnectionLost(self, reason):
self.close()
def fileno(self):
return self._fd.fileno()
def connectionLost(self, reason):
self.close()
def connectionLost(self, reason):
self.close()
def doRead(self):
self._callback(self._fd)
def doWrite(self):
self._callback(self._fd)
if have_twisted:
Writer = implementer(IWriteDescriptor)(Writer)
@implementer(IWriteDescriptor)
class Writer(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Writer"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def connectionLost(self, reason):
self.close()
def doWrite(self):
self._callback(self._fd)
@skipIfNoTwisted
@ -362,7 +364,7 @@ class CompatibilityTests(unittest.TestCase):
self.saved_signals = save_signal_handlers()
self.io_loop = IOLoop()
self.io_loop.make_current()
self.reactor = TornadoReactor(self.io_loop)
self.reactor = TornadoReactor()
def tearDown(self):
self.reactor.disconnectAll()
@ -386,7 +388,7 @@ class CompatibilityTests(unittest.TestCase):
self.write("Hello from tornado!")
app = Application([('/', HelloHandler)],
log_function=lambda x: None)
server = HTTPServer(app, io_loop=self.io_loop)
server = HTTPServer(app)
sock, self.tornado_port = bind_unused_port()
server.add_sockets([sock])
@ -402,7 +404,7 @@ class CompatibilityTests(unittest.TestCase):
def tornado_fetch(self, url, runner):
responses = []
client = AsyncHTTPClient(self.io_loop)
client = AsyncHTTPClient()
def callback(response):
responses.append(response)
@ -495,7 +497,6 @@ class CompatibilityTests(unittest.TestCase):
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
@skipIfPy26
def testTornadoServerTwistedCoroutineClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_coroutine_fetch(
@ -504,7 +505,6 @@ class CompatibilityTests(unittest.TestCase):
@skipIfNoTwisted
@skipIfPy26
class ConvertDeferredTest(unittest.TestCase):
def test_success(self):
@inlineCallbacks
@ -610,14 +610,14 @@ if have_twisted:
test_class = import_object(test_name)
except (ImportError, AttributeError):
continue
for test_func in blacklist:
for test_func in blacklist: # type: ignore
if hasattr(test_class, test_func):
# The test_func may be defined in a mixin, so clobber
# it instead of delattr()
setattr(test_class, test_func, lambda self: None)
def make_test_subclass(test_class):
class TornadoTest(test_class):
class TornadoTest(test_class): # type: ignore
_reactors = ["tornado.platform.twisted._TestReactor"]
def setUp(self):
@ -627,10 +627,10 @@ if have_twisted:
self.__curdir = os.getcwd()
self.__tempdir = tempfile.mkdtemp()
os.chdir(self.__tempdir)
super(TornadoTest, self).setUp()
super(TornadoTest, self).setUp() # type: ignore
def tearDown(self):
super(TornadoTest, self).tearDown()
super(TornadoTest, self).tearDown() # type: ignore
os.chdir(self.__curdir)
shutil.rmtree(self.__tempdir)
@ -645,7 +645,7 @@ if have_twisted:
# enabled) but without our filter rules to ignore those
# warnings from Twisted code.
filtered = []
for w in super(TornadoTest, self).flushWarnings(
for w in super(TornadoTest, self).flushWarnings( # type: ignore
*args, **kwargs):
if w['category'] in (BytesWarning, ResourceWarning):
continue
@ -682,15 +682,15 @@ if have_twisted:
# Twisted recently introduced a new logger; disable that one too.
try:
from twisted.logger import globalLogBeginner
from twisted.logger import globalLogBeginner # type: ignore
except ImportError:
pass
else:
globalLogBeginner.beginLoggingTo([])
globalLogBeginner.beginLoggingTo([], redirectStandardIO=False)
if have_twisted:
class LayeredTwistedIOLoop(TwistedIOLoop):
"""Layers a TwistedIOLoop on top of a TornadoReactor on a SelectIOLoop.
"""Layers a TwistedIOLoop on top of a TornadoReactor on a PollIOLoop.
This is of course silly, but is useful for testing purposes to make
sure we're implementing both sides of the various interfaces
@ -698,11 +698,8 @@ if have_twisted:
of the whole stack.
"""
def initialize(self, **kwargs):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop(make_current=False)
reactor = TornadoReactor(io_loop=self.real_io_loop)
self.real_io_loop = PollIOLoop(make_current=False) # type: ignore
reactor = self.real_io_loop.run_sync(gen.coroutine(TornadoReactor))
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
self.add_callback(self.make_current)

View file

@ -1,22 +1,17 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import contextlib
import os
import platform
import socket
import sys
import textwrap
import warnings
from tornado.testing import bind_unused_port
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info < (2, 7):
# In py26, we must always use unittest2.
import unittest2 as unittest
else:
# Otherwise, use whichever version of unittest was imported in
# tornado.testing.
from tornado.testing import unittest
# Delegate the choice of unittest or unittest2 to tornado.testing.
from tornado.testing import unittest
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
"non-unix platform")
@ -34,14 +29,39 @@ skipOnAppEngine = unittest.skipIf('APPENGINE_RUNTIME' in os.environ,
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 (yield from) not available')
skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
'Not CPython implementation')
# Used for tests affected by
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
# TODO: remove this after pypy3 5.8 is obsolete.
skipPypy3V58 = unittest.skipIf(platform.python_implementation() == 'PyPy' and
sys.version_info > (3,) and
sys.pypy_version_info < (5, 9),
'pypy3 5.8 has buggy ssl module')
def _detect_ipv6():
if not socket.has_ipv6:
# socket.has_ipv6 check reports whether ipv6 was present at compile
# time. It's usually true even when ipv6 doesn't work for other reasons.
return False
sock = None
try:
sock = socket.socket(socket.AF_INET6)
sock.bind(('::1', 0))
except socket.error:
return False
finally:
if sock is not None:
sock.close()
return True
skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), 'ipv6 support not present')
def refusing_port():
"""Returns a local port number that will refuse all connections.
@ -72,7 +92,27 @@ def exec_test(caller_globals, caller_locals, s):
# Flatten the real global and local namespace into our fake
# globals: it's all global from the perspective of code defined
# in s.
global_namespace = dict(caller_globals, **caller_locals)
global_namespace = dict(caller_globals, **caller_locals) # type: ignore
local_namespace = {}
exec(textwrap.dedent(s), global_namespace, local_namespace)
return local_namespace
def subTest(test, *args, **kwargs):
"""Compatibility shim for unittest.TestCase.subTest.
Usage: ``with tornado.test.util.subTest(self, x=x):``
"""
try:
subTest = test.subTest # py34+
except AttributeError:
subTest = contextlib.contextmanager(lambda *a, **kw: (yield))
return subTest(*args, **kwargs)
@contextlib.contextmanager
def ignore_deprecation():
"""Context manager to ignore deprecation warnings."""
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
yield

View file

@ -1,17 +1,21 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import re
import sys
import datetime
import tornado.escape
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
from tornado.test.util import unittest
from tornado.util import (
raise_exc_info, Configurable, exec_in, ArgReplacer,
timedelta_to_seconds, import_object, re_unescape, is_finalizing, PY3,
)
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
if PY3:
from io import StringIO
else:
from cStringIO import StringIO
class RaiseExcInfoTest(unittest.TestCase):
@ -57,12 +61,35 @@ class TestConfig2(TestConfigurable):
self.pos_arg = pos_arg
class TestConfig3(TestConfigurable):
# TestConfig3 is a configuration option that is itself configurable.
@classmethod
def configurable_base(cls):
return TestConfig3
@classmethod
def configurable_default(cls):
return TestConfig3A
class TestConfig3A(TestConfig3):
def initialize(self, a=None):
self.a = a
class TestConfig3B(TestConfig3):
def initialize(self, b=None):
self.b = b
class ConfigurableTest(unittest.TestCase):
def setUp(self):
self.saved = TestConfigurable._save_configuration()
self.saved3 = TestConfig3._save_configuration()
def tearDown(self):
TestConfigurable._restore_configuration(self.saved)
TestConfig3._restore_configuration(self.saved3)
def checkSubclasses(self):
# no matter how the class is configured, it should always be
@ -130,10 +157,43 @@ class ConfigurableTest(unittest.TestCase):
obj = TestConfig2()
self.assertIs(obj.b, None)
def test_config_multi_level(self):
TestConfigurable.configure(TestConfig3, a=1)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig3A)
self.assertEqual(obj.a, 1)
TestConfigurable.configure(TestConfig3)
TestConfig3.configure(TestConfig3B, b=2)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig3B)
self.assertEqual(obj.b, 2)
def test_config_inner_level(self):
# The inner level can be used even when the outer level
# doesn't point to it.
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3A)
TestConfig3.configure(TestConfig3B)
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3B)
# Configuring the base doesn't configure the inner.
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig1)
TestConfigurable.configure(TestConfig2)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig2)
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3B)
class UnicodeLiteralTest(unittest.TestCase):
def test_unicode_escapes(self):
self.assertEqual(utf8(u('\u00e9')), b'\xc3\xa9')
self.assertEqual(utf8(u'\u00e9'), b'\xc3\xa9')
class ExecInTest(unittest.TestCase):
@ -189,7 +249,7 @@ class ImportObjectTest(unittest.TestCase):
self.assertIs(import_object('tornado.escape.utf8'), utf8)
def test_import_member_unicode(self):
self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
self.assertIs(import_object(u'tornado.escape.utf8'), utf8)
def test_import_module(self):
self.assertIs(import_object('tornado.escape'), tornado.escape)
@ -198,4 +258,29 @@ class ImportObjectTest(unittest.TestCase):
# The internal implementation of __import__ differs depending on
# whether the thing being imported is a module or not.
# This variant requires a byte string in python 2.
self.assertIs(import_object(u('tornado.escape')), tornado.escape)
self.assertIs(import_object(u'tornado.escape'), tornado.escape)
class ReUnescapeTest(unittest.TestCase):
def test_re_unescape(self):
test_strings = (
'/favicon.ico',
'index.html',
'Hello, World!',
'!$@#%;',
)
for string in test_strings:
self.assertEqual(string, re_unescape(re.escape(string)))
def test_re_unescape_raises_error_on_invalid_input(self):
with self.assertRaises(ValueError):
re_unescape('\\d')
with self.assertRaises(ValueError):
re_unescape('\\b')
with self.assertRaises(ValueError):
re_unescape('\\Z')
class IsFinalizingTest(unittest.TestCase):
def test_basic(self):
self.assertFalse(is_finalizing())

View file

@ -1,18 +1,26 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado.concurrent import Future
from tornado import gen
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring # noqa: E501
from tornado.httpclient import HTTPClientError
from tornado.httputil import format_timestamp
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import locale
from tornado.locks import Event
from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipBefore35, exec_test
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version, GZipContentEncoding
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
from tornado.util import ObjectDict, unicode_type, timedelta_to_seconds, PY3
from tornado.web import (
Application, RequestHandler, StaticFileHandler, RedirectHandler as WebRedirectHandler,
HTTPError, MissingArgumentError, ErrorHandler, authenticated, asynchronous, url,
_create_signature_v1, create_signed_value, decode_signed_value, get_signature_key_version,
UIModule, Finish, stream_request_body, removeslash, addslash, GZipContentEncoding,
)
import binascii
import contextlib
@ -27,14 +35,16 @@ import os
import re
import socket
try:
if PY3:
import urllib.parse as urllib_parse # py3
except ImportError:
else:
import urllib as urllib_parse # py2
wsgi_safe_tests = []
relpath = lambda *a: os.path.join(os.path.dirname(__file__), *a)
def relpath(*a):
return os.path.join(os.path.dirname(__file__), *a)
def wsgi_safe(cls):
@ -181,6 +191,42 @@ class SecureCookieV2Test(unittest.TestCase):
self.assertEqual(new_handler.get_secure_cookie('foo'), None)
class FinalReturnTest(WebTestCase):
def get_handlers(self):
test = self
class FinishHandler(RequestHandler):
@gen.coroutine
def get(self):
test.final_return = self.finish()
yield test.final_return
class RenderHandler(RequestHandler):
def create_template_loader(self, path):
return DictLoader({'foo.html': 'hi'})
@gen.coroutine
def get(self):
test.final_return = self.render('foo.html')
return [("/finish", FinishHandler),
("/render", RenderHandler)]
def get_app_kwargs(self):
return dict(template_path='FinalReturnTest')
def test_finish_method_return_future(self):
response = self.fetch(self.get_url('/finish'))
self.assertEqual(response.code, 200)
self.assertIsInstance(self.final_return, Future)
self.assertTrue(self.final_return.done())
def test_render_method_return_future(self):
response = self.fetch(self.get_url('/render'))
self.assertEqual(response.code, 200)
self.assertIsInstance(self.final_return, Future)
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
@ -188,7 +234,7 @@ class CookieTest(WebTestCase):
# Try setting cookies with different argument types
# to ensure that everything gets encoded correctly
self.set_cookie("str", "asdf")
self.set_cookie("unicode", u("qwer"))
self.set_cookie("unicode", u"qwer")
self.set_cookie("bytes", b"zxcv")
class GetCookieHandler(RequestHandler):
@ -199,8 +245,8 @@ class CookieTest(WebTestCase):
def get(self):
# unicode domain and path arguments shouldn't break things
# either (see bug #285)
self.set_cookie("unicode_args", "blah", domain=u("foo.com"),
path=u("/foo"))
self.set_cookie("unicode_args", "blah", domain=u"foo.com",
path=u"/foo")
class SetCookieSpecialCharHandler(RequestHandler):
def get(self):
@ -277,8 +323,8 @@ class CookieTest(WebTestCase):
data = [('foo=a=b', 'a=b'),
('foo="a=b"', 'a=b'),
('foo="a;b"', 'a;b'),
# ('foo=a\\073b', 'a;b'), # even encoded, ";" is a delimiter
('foo="a;b"', '"a'), # even quoted, ";" is a delimiter
('foo=a\\073b', 'a\\073b'), # escapes only decoded in quotes
('foo="a\\073b"', 'a;b'),
('foo="a\\"b"', 'a"b'),
]
@ -342,19 +388,17 @@ class AuthRedirectTest(WebTestCase):
dict(login_url='http://example.com/login'))]
def test_relative_auth_redirect(self):
self.http_client.fetch(self.get_url('/relative'), self.stop,
follow_redirects=False)
response = self.wait()
response = self.fetch(self.get_url('/relative'),
follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/login?next=%2Frelative')
def test_absolute_auth_redirect(self):
self.http_client.fetch(self.get_url('/absolute'), self.stop,
follow_redirects=False)
response = self.wait()
response = self.fetch(self.get_url('/absolute'),
follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(re.match(
'http://example.com/login\?next=http%3A%2F%2Flocalhost%3A[0-9]+%2Fabsolute',
'http://example.com/login\?next=http%3A%2F%2F127.0.0.1%3A[0-9]+%2Fabsolute',
response.headers['Location']), response.headers['Location'])
@ -362,9 +406,11 @@ class ConnectionCloseHandler(RequestHandler):
def initialize(self, test):
self.test = test
@asynchronous
@gen.coroutine
def get(self):
self.test.on_handler_waiting()
never_finish = Event()
yield never_finish.wait()
def on_connection_close(self):
self.test.on_connection_close()
@ -377,7 +423,7 @@ class ConnectionCloseTest(WebTestCase):
def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("127.0.0.1", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream = IOStream(s)
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
self.wait()
@ -434,9 +480,9 @@ class RequestEncodingTest(WebTestCase):
def test_group_encoding(self):
# Path components and query arguments should be decoded the same way
self.assertEqual(self.fetch_json('/group/%C3%A9?arg=%C3%A9'),
{u("path"): u("/group/%C3%A9"),
u("path_args"): [u("\u00e9")],
u("args"): {u("arg"): [u("\u00e9")]}})
{u"path": u"/group/%C3%A9",
u"path_args": [u"\u00e9"],
u"args": {u"arg": [u"\u00e9"]}})
def test_slashes(self):
# Slashes may be escaped to appear as a single "directory" in the path,
@ -541,14 +587,17 @@ class OptionalPathHandler(RequestHandler):
class FlowControlHandler(RequestHandler):
# These writes are too small to demonstrate real flow control,
# but at least it shows that the callbacks get run.
@asynchronous
def get(self):
self.write("1")
self.flush(callback=self.step2)
with ignore_deprecation():
@asynchronous
def get(self):
self.write("1")
with ignore_deprecation():
self.flush(callback=self.step2)
def step2(self):
self.write("2")
self.flush(callback=self.step3)
with ignore_deprecation():
self.flush(callback=self.step3)
def step3(self):
self.write("3")
@ -574,14 +623,13 @@ class RedirectHandler(RequestHandler):
class EmptyFlushCallbackHandler(RequestHandler):
@asynchronous
@gen.engine
@gen.coroutine
def get(self):
# Ensure that the flush callback is run whether or not there
# was any output. The gen.Task and direct yield forms are
# equivalent.
yield gen.Task(self.flush) # "empty" flush, but writes headers
yield gen.Task(self.flush) # empty flush
yield self.flush() # "empty" flush, but writes headers
yield self.flush() # empty flush
self.write("o")
yield self.flush() # flushes the "o"
yield self.flush() # empty flush
@ -633,7 +681,12 @@ class WSGISafeWebTest(WebTestCase):
{% end %}
</body></html>""",
"entry.html": """\
{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }", embedded_javascript="js_embed()", css_files=["/base.css", "/foo.css"], javascript_files="/common.js", html_head="<meta>", html_body='<script src="/analytics.js"/>') }}
{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }",
embedded_javascript="js_embed()",
css_files=["/base.css", "/foo.css"],
javascript_files="/common.js",
html_head="<meta>",
html_body='<script src="/analytics.js"/>') }}
<div class="entry">...</div>""",
})
return dict(template_loader=loader,
@ -655,8 +708,10 @@ class WSGISafeWebTest(WebTestCase):
url("/multi_header", MultiHeaderHandler),
url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
url("/web_redirect", WebRedirectHandler,
{"url": "/web_redirect_newpath", "permanent": False}),
url("//web_redirect_double_slash", WebRedirectHandler,
{"url": '/web_redirect_newpath'}),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
@ -690,15 +745,15 @@ class WSGISafeWebTest(WebTestCase):
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
self.assertEqual(data, {u('path'): [u('unicode'), u('\u00e9')],
u('query'): [u('unicode'), u('\u00e9')],
self.assertEqual(data, {u'path': [u'unicode', u'\u00e9'],
u'query': [u'unicode', u'\u00e9'],
})
response = self.fetch("/decode_arg/%C3%A9?foo=%C3%A9")
response.rethrow()
data = json_decode(response.body)
self.assertEqual(data, {u('path'): [u('bytes'), u('c3a9')],
u('query'): [u('bytes'), u('c3a9')],
self.assertEqual(data, {u'path': [u'bytes', u'c3a9'],
u'query': [u'bytes', u'c3a9'],
})
def test_decode_argument_invalid_unicode(self):
@ -717,8 +772,8 @@ class WSGISafeWebTest(WebTestCase):
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
self.assertEqual(data, {u('path'): [u('unicode'), u('1 + 1')],
u('query'): [u('unicode'), u('1 + 1')],
self.assertEqual(data, {u'path': [u'unicode', u'1 + 1'],
u'query': [u'unicode', u'1 + 1'],
})
def test_reverse_url(self):
@ -728,7 +783,7 @@ class WSGISafeWebTest(WebTestCase):
'/decode_arg/42')
self.assertEqual(self.app.reverse_url('decode_arg', b'\xe9'),
'/decode_arg/%E9')
self.assertEqual(self.app.reverse_url('decode_arg', u('\u00e9')),
self.assertEqual(self.app.reverse_url('decode_arg', u'\u00e9'),
'/decode_arg/%C3%A9')
self.assertEqual(self.app.reverse_url('decode_arg', '1 + 1'),
'/decode_arg/1%20%2B%201')
@ -761,13 +816,13 @@ js_embed()
//]]>
</script>
<script src="/analytics.js"/>
</body></html>""")
</body></html>""") # noqa: E501
def test_optional_path(self):
self.assertEqual(self.fetch_json("/optional_path/foo"),
{u("path"): u("foo")})
{u"path": u"foo"})
self.assertEqual(self.fetch_json("/optional_path/"),
{u("path"): None})
{u"path": None})
def test_multi_header(self):
response = self.fetch("/multi_header")
@ -915,6 +970,10 @@ class ErrorResponseTest(WebTestCase):
self.assertEqual(response.code, 503)
self.assertTrue(b"503: Service Unavailable" in response.body)
response = self.fetch("/default?status=435")
self.assertEqual(response.code, 435)
self.assertTrue(b"435: Unknown" in response.body)
def test_write_error(self):
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/write_error")
@ -1062,6 +1121,13 @@ class StaticFileTest(WebTestCase):
'If-None-Match': response1.headers['Etag']})
self.assertEqual(response2.code, 304)
def test_static_304_etag_modified_bug(self):
response1 = self.get_and_head("/static/robots.txt")
response2 = self.get_and_head("/static/robots.txt", headers={
'If-None-Match': '"MISMATCH"',
'If-Modified-Since': response1.headers['Last-Modified']})
self.assertEqual(response2.code, 200)
def test_static_if_modified_since_pre_epoch(self):
# On windows, the functions that work with time_t do not accept
# negative values, and at least one client (processing.js) seems
@ -1346,6 +1412,8 @@ class HostMatchingTest(WebTestCase):
[("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
self.app.add_handlers("www.example.com",
[("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
self.app.add_handlers("www.e.*e.com",
[("/baz", HostMatchingTest.Handler, {"reply": "[3]"})])
response = self.fetch("/foo")
self.assertEqual(response.body, b"wildcard")
@ -1360,6 +1428,40 @@ class HostMatchingTest(WebTestCase):
self.assertEqual(response.body, b"[1]")
response = self.fetch("/baz", headers={'Host': 'www.example.com'})
self.assertEqual(response.body, b"[2]")
response = self.fetch("/baz", headers={'Host': 'www.exe.com'})
self.assertEqual(response.body, b"[3]")
@wsgi_safe
class DefaultHostMatchingTest(WebTestCase):
def get_handlers(self):
return []
def get_app_kwargs(self):
return {'default_host': "www.example.com"}
def test_default_host_matching(self):
self.app.add_handlers("www.example.com",
[("/foo", HostMatchingTest.Handler, {"reply": "[0]"})])
self.app.add_handlers(r"www\.example\.com",
[("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
self.app.add_handlers("www.test.com",
[("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
response = self.fetch("/foo")
self.assertEqual(response.body, b"[0]")
response = self.fetch("/bar")
self.assertEqual(response.body, b"[1]")
response = self.fetch("/baz")
self.assertEqual(response.code, 404)
response = self.fetch("/foo", headers={"X-Real-Ip": "127.0.0.1"})
self.assertEqual(response.code, 404)
self.app.default_host = "www.test.com"
response = self.fetch("/baz")
self.assertEqual(response.body, b"[2]")
@wsgi_safe
@ -1370,7 +1472,7 @@ class NamedURLSpecGroupsTest(WebTestCase):
self.write(path)
return [("/str/(?P<path>.*)", EchoHandler),
(u("/unicode/(?P<path>.*)"), EchoHandler)]
(u"/unicode/(?P<path>.*)", EchoHandler)]
def test_named_urlspec_groups(self):
response = self.fetch("/str/foo")
@ -1395,6 +1497,19 @@ class ClearHeaderTest(SimpleHandlerTestCase):
self.assertEqual(response.headers["h2"], "bar")
class Header204Test(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.set_status(204)
self.finish()
def test_204_headers(self):
response = self.fetch('/')
self.assertEqual(response.code, 204)
self.assertNotIn("Content-Length", response.headers)
self.assertNotIn("Transfer-Encoding", response.headers)
@wsgi_safe
class Header304Test(SimpleHandlerTestCase):
class Handler(RequestHandler):
@ -1426,7 +1541,7 @@ class StatusReasonTest(SimpleHandlerTestCase):
def get_http_client(self):
# simple_httpclient only: curl doesn't expose the reason string
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
return SimpleAsyncHTTPClient()
def test_status(self):
response = self.fetch("/?code=304")
@ -1438,9 +1553,9 @@ class StatusReasonTest(SimpleHandlerTestCase):
response = self.fetch("/?code=682&reason=Bar")
self.assertEqual(response.code, 682)
self.assertEqual(response.reason, "Bar")
with ExpectLog(app_log, 'Uncaught exception'):
response = self.fetch("/?code=682")
self.assertEqual(response.code, 500)
response = self.fetch("/?code=682")
self.assertEqual(response.code, 682)
self.assertEqual(response.reason, "Unknown")
@wsgi_safe
@ -1465,7 +1580,7 @@ class RaiseWithReasonTest(SimpleHandlerTestCase):
def get_http_client(self):
# simple_httpclient only: curl doesn't expose the reason string
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
return SimpleAsyncHTTPClient()
def test_raise_with_reason(self):
response = self.fetch("/")
@ -1476,6 +1591,9 @@ class RaiseWithReasonTest(SimpleHandlerTestCase):
def test_httperror_str(self):
self.assertEqual(str(HTTPError(682, reason="Foo")), "HTTP 682: Foo")
def test_httperror_str_from_httputil(self):
self.assertEqual(str(HTTPError(682)), "HTTP 682: Unknown")
@wsgi_safe
class ErrorHandlerXSRFTest(WebTestCase):
@ -1501,8 +1619,8 @@ class ErrorHandlerXSRFTest(WebTestCase):
class GzipTestCase(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
if self.get_argument('vary', None):
self.set_header('Vary', self.get_argument('vary'))
for v in self.get_arguments('vary'):
self.add_header('Vary', v)
# Must write at least MIN_LENGTH bytes to activate compression.
self.write('hello world' + ('!' * GZipContentEncoding.MIN_LENGTH))
@ -1511,8 +1629,7 @@ class GzipTestCase(SimpleHandlerTestCase):
gzip=True,
static_path=os.path.join(os.path.dirname(__file__), 'static'))
def test_gzip(self):
response = self.fetch('/')
def assert_compressed(self, response):
# simple_httpclient renames the content-encoding header;
# curl_httpclient doesn't.
self.assertEqual(
@ -1520,17 +1637,17 @@ class GzipTestCase(SimpleHandlerTestCase):
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
def test_gzip(self):
response = self.fetch('/')
self.assert_compressed(response)
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_static(self):
# The streaming responses in StaticFileHandler have subtle
# interactions with the gzip output so test this case separately.
response = self.fetch('/robots.txt')
self.assertEqual(
response.headers.get(
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
self.assert_compressed(response)
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_not_requested(self):
@ -1540,8 +1657,16 @@ class GzipTestCase(SimpleHandlerTestCase):
def test_vary_already_present(self):
response = self.fetch('/?vary=Accept-Language')
self.assertEqual(response.headers['Vary'],
'Accept-Language, Accept-Encoding')
self.assert_compressed(response)
self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
['Accept-Language', 'Accept-Encoding'])
def test_vary_already_present_multiple(self):
# Regression test for https://github.com/tornadoweb/tornado/issues/1670
response = self.fetch('/?vary=Accept-Language&vary=Cookie')
self.assert_compressed(response)
self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
['Accept-Language', 'Cookie', 'Accept-Encoding'])
@wsgi_safe
@ -1721,28 +1846,29 @@ class MultipleExceptionTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
exc_count = 0
@asynchronous
def get(self):
from tornado.ioloop import IOLoop
IOLoop.current().add_callback(lambda: 1 / 0)
IOLoop.current().add_callback(lambda: 1 / 0)
with ignore_deprecation():
@asynchronous
def get(self):
IOLoop.current().add_callback(lambda: 1 / 0)
IOLoop.current().add_callback(lambda: 1 / 0)
def log_exception(self, typ, value, tb):
MultipleExceptionTest.Handler.exc_count += 1
def test_multi_exception(self):
# This test verifies that multiple exceptions raised into the same
# ExceptionStackContext do not generate extraneous log entries
# due to "Cannot send error response after headers written".
# log_exception is called, but it does not proceed to send_error.
response = self.fetch('/')
self.assertEqual(response.code, 500)
response = self.fetch('/')
self.assertEqual(response.code, 500)
# Each of our two requests generated two exceptions, we should have
# seen at least three of them by now (the fourth may still be
# in the queue).
self.assertGreater(MultipleExceptionTest.Handler.exc_count, 2)
with ignore_deprecation():
# This test verifies that multiple exceptions raised into the same
# ExceptionStackContext do not generate extraneous log entries
# due to "Cannot send error response after headers written".
# log_exception is called, but it does not proceed to send_error.
response = self.fetch('/')
self.assertEqual(response.code, 500)
response = self.fetch('/')
self.assertEqual(response.code, 500)
# Each of our two requests generated two exceptions, we should have
# seen at least three of them by now (the fourth may still be
# in the queue).
self.assertGreater(MultipleExceptionTest.Handler.exc_count, 2)
@wsgi_safe
@ -2050,7 +2176,7 @@ class StreamingRequestBodyTest(WebTestCase):
# Use a raw connection so we can control the sending of data.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("127.0.0.1", self.get_http_port()))
stream = IOStream(s, io_loop=self.io_loop)
stream = IOStream(s)
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
if connection_close:
stream.write(b"Connection: close\r\n")
@ -2073,9 +2199,9 @@ class StreamingRequestBodyTest(WebTestCase):
stream.write(b"4\r\nqwer\r\n")
data = yield self.data
self.assertEquals(data, b"qwer")
stream.write(b"0\r\n")
stream.write(b"0\r\n\r\n")
yield self.finished
data = yield gen.Task(stream.read_until_close)
data = yield stream.read_until_close()
# This would ideally use an HTTP1Connection to read the response.
self.assertTrue(data.endswith(b"{}"))
stream.close()
@ -2083,14 +2209,14 @@ class StreamingRequestBodyTest(WebTestCase):
@gen_test
def test_early_return(self):
stream = self.connect(b"/early_return", connection_close=False)
data = yield gen.Task(stream.read_until_close)
data = yield stream.read_until_close()
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
@gen_test
def test_early_return_with_data(self):
stream = self.connect(b"/early_return", connection_close=False)
stream.write(b"4\r\nasdf\r\n")
data = yield gen.Task(stream.read_until_close)
data = yield stream.read_until_close()
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
@gen_test
@ -2129,12 +2255,12 @@ class BaseFlowControlHandler(RequestHandler):
# Note that asynchronous prepare() does not block data_received,
# so we don't use in_method here.
self.methods.append('prepare')
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
@gen.coroutine
def post(self):
with self.in_method('post'):
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
self.write(dict(methods=self.methods))
@ -2146,7 +2272,7 @@ class BaseStreamingRequestFlowControlTest(object):
def get_http_client(self):
# simple_httpclient only: curl doesn't support body_producer.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
return SimpleAsyncHTTPClient()
# Test all the slightly different code paths for fixed, chunked, etc bodies.
def test_flow_control_fixed_body(self):
@ -2160,6 +2286,7 @@ class BaseStreamingRequestFlowControlTest(object):
def test_flow_control_chunked_body(self):
chunks = [b'abcd', b'efgh', b'ijkl']
@gen.coroutine
def body_producer(write):
for i in chunks:
@ -2185,6 +2312,7 @@ class BaseStreamingRequestFlowControlTest(object):
'data_received', 'data_received',
'post']))
class DecoratedStreamingRequestFlowControlTest(
BaseStreamingRequestFlowControlTest,
WebTestCase):
@ -2193,7 +2321,7 @@ class DecoratedStreamingRequestFlowControlTest(
@gen.coroutine
def data_received(self, data):
with self.in_method('data_received'):
yield gen.Task(IOLoop.current().add_callback)
yield gen.moment
return [('/', DecoratedFlowControlHandler, dict(test=self))]
@ -2206,7 +2334,8 @@ class NativeStreamingRequestFlowControlTest(
data_received = exec_test(globals(), locals(), """
async def data_received(self, data):
with self.in_method('data_received'):
await gen.Task(IOLoop.current().add_callback)
import asyncio
await asyncio.sleep(0)
""")["data_received"]
return [('/', NativeFlowControlHandler, dict(test=self))]
@ -2247,8 +2376,8 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
with ExpectLog(gen_log,
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/high")
self.assertEqual(response.code, 599)
with self.assertRaises(HTTPClientError):
self.fetch("/high", raise_error=True)
self.assertEqual(str(self.server_error),
"Tried to write 40 bytes less than Content-Length")
@ -2260,8 +2389,8 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
with ExpectLog(gen_log,
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/low")
self.assertEqual(response.code, 599)
with self.assertRaises(HTTPClientError):
self.fetch("/low", raise_error=True)
self.assertEqual(str(self.server_error),
"Tried to write more data than Content-Length")
@ -2282,10 +2411,11 @@ class ClientCloseTest(SimpleHandlerTestCase):
self.write('requires HTTP/1.x')
def test_client_close(self):
response = self.fetch('/')
if response.body == b'requires HTTP/1.x':
self.skipTest('requires HTTP/1.x')
self.assertEqual(response.code, 599)
with self.assertRaises((HTTPClientError, unittest.SkipTest)):
response = self.fetch('/', raise_error=True)
if response.body == b'requires HTTP/1.x':
self.skipTest('requires HTTP/1.x')
self.assertEqual(response.code, 599)
class SignedValueTest(unittest.TestCase):
@ -2483,6 +2613,22 @@ class XSRFTest(SimpleHandlerTestCase):
body=urllib_parse.urlencode(dict(_xsrf=self.xsrf_token)))
self.assertEqual(response.code, 403)
def test_xsrf_fail_argument_invalid_format(self):
with ExpectLog(gen_log, ".*'_xsrf' argument has invalid format"):
response = self.fetch(
"/", method="POST",
headers=self.cookie_headers(),
body=urllib_parse.urlencode(dict(_xsrf='3|')))
self.assertEqual(response.code, 403)
def test_xsrf_fail_cookie_invalid_format(self):
with ExpectLog(gen_log, ".*XSRF cookie does not match POST"):
response = self.fetch(
"/", method="POST",
headers=self.cookie_headers(token='3|'),
body=urllib_parse.urlencode(dict(_xsrf=self.xsrf_token)))
self.assertEqual(response.code, 403)
def test_xsrf_fail_cookie_no_body(self):
with ExpectLog(gen_log, ".*'_xsrf' argument missing"):
response = self.fetch(
@ -2520,7 +2666,7 @@ class XSRFTest(SimpleHandlerTestCase):
def test_xsrf_success_header(self):
response = self.fetch("/", method="POST", body=b"",
headers=dict({"X-Xsrftoken": self.xsrf_token},
headers=dict({"X-Xsrftoken": self.xsrf_token}, # type: ignore
**self.cookie_headers()))
self.assertEqual(response.code, 200)
@ -2623,8 +2769,8 @@ class FinishExceptionTest(SimpleHandlerTestCase):
raise Finish()
def test_finish_exception(self):
for url in ['/', '/?finish_value=1']:
response = self.fetch(url)
for u in ['/', '/?finish_value=1']:
response = self.fetch(u)
self.assertEqual(response.code, 401)
self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate'))
@ -2763,3 +2909,59 @@ class ApplicationTest(AsyncTestCase):
app = Application([])
server = app.listen(0, address='127.0.0.1')
server.stop()
class URLSpecReverseTest(unittest.TestCase):
    """Tests for ``URLSpec.reverse()`` on literal and regex patterns."""

    def test_reverse(self):
        # Anchors (^...$) are optional and must not leak into the result.
        for pattern in (r'/favicon\.ico', r'^/favicon\.ico$'):
            self.assertEqual('/favicon.ico', url(pattern, None).reverse())

    def test_non_reversible(self):
        # URLSpecs are non-reversible if they include non-constant
        # regex features outside capturing groups. Currently, this is
        # only strictly enforced for backslash-escaped character
        # classes.
        non_reversible = [
            r'^/api/v\d+/foo/(\w+)$',
        ]
        for pattern in non_reversible:
            # A URLSpec can still be created even if it cannot be reversed.
            spec = url(pattern, None)
            with self.assertRaises(ValueError):
                spec.reverse()

    def test_reverse_arguments(self):
        self.assertEqual('/api/v1/foo/bar',
                         url(r'^/api/v1/foo/(\w+)$', None).reverse('bar'))
class RedirectHandlerTest(WebTestCase):
    """Tests for RedirectHandler routing and URL rewriting."""

    def get_handlers(self):
        return [
            ('/src', WebRedirectHandler, {'url': '/dst'}),
            ('/src2', WebRedirectHandler, {'url': '/dst2?foo=bar'}),
            (r'/(.*?)/(.*?)/(.*)', WebRedirectHandler, {'url': '/{1}/{0}/{2}'})]

    def _check_redirect(self, path, location):
        # Fetch without following redirects so the 301 itself is visible.
        resp = self.fetch(path, follow_redirects=False)
        self.assertEqual(resp.code, 301)
        self.assertEqual(resp.headers['Location'], location)

    def test_basic_redirect(self):
        self._check_redirect('/src', '/dst')

    def test_redirect_with_argument(self):
        # Query arguments on the request are carried over to the target.
        self._check_redirect('/src?foo=bar', '/dst?foo=bar')

    def test_redirect_with_appending_argument(self):
        # Request arguments are appended to any already present on the target.
        self._check_redirect('/src2?foo2=bar2', '/dst2?foo=bar&foo2=bar2')

    def test_redirect_pattern(self):
        # Capture groups can be reordered via {n} placeholders.
        self._check_redirect('/a/b/c', '/b/a/c')

View file

@ -1,15 +1,19 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools
import sys
import traceback
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest
from tornado.test.util import unittest, skipBefore35, exec_test
from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket # noqa
@ -22,7 +26,9 @@ except ImportError:
traceback.print_exc()
raise
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
from tornado.websocket import (
WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError,
)
try:
from tornado import speedups
@ -47,8 +53,12 @@ class TestWebSocketHandler(WebSocketHandler):
class EchoHandler(TestWebSocketHandler):
@gen.coroutine
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
try:
yield self.write_message(message, isinstance(message, bytes))
except WebSocketClosedError:
pass
class ErrorInOnMessageHandler(TestWebSocketHandler):
@ -58,16 +68,36 @@ class ErrorInOnMessageHandler(TestWebSocketHandler):
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
self.set_status(503)
raise Exception("did not get expected exception")
except RuntimeError:
pass
methods_to_test = [
functools.partial(self.write, 'This should not work'),
functools.partial(self.redirect, 'http://localhost/elsewhere'),
functools.partial(self.set_header, 'X-Test', ''),
functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
functools.partial(self.set_status, 503),
self.flush,
self.finish,
]
for method in methods_to_test:
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
method()
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get('X-Test', ''))
class HeaderEchoHandler(TestWebSocketHandler):
    """Echoes ``X-Test*`` request headers back on the handshake response."""

    def set_default_headers(self):
        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")

    def prepare(self):
        # Copy every X-Test* header from the request onto the response.
        for name, value in self.request.headers.get_all():
            if name.lower().startswith('x-test'):
                self.set_header(name, value)
class NonWebSocketHandler(RequestHandler):
    # Plain HTTP handler; used to verify that websocket_connect fails
    # cleanly when pointed at a non-websocket endpoint.
    def get(self):
        self.write('ok')
@ -88,12 +118,75 @@ class AsyncPrepareHandler(TestWebSocketHandler):
self.write_message(message)
class PathArgsHandler(TestWebSocketHandler):
    # Echoes the captured URL path group back to the client, verifying
    # that routing pattern arguments are passed through to open().
    def open(self, arg):
        self.write_message(arg)
class CoroutineOnMessageHandler(TestWebSocketHandler):
    """Echo handler whose ``on_message`` is a coroutine.

    Reports an error message if two on_message invocations ever overlap,
    i.e. if the server fails to process messages one at a time.
    """

    def initialize(self, close_future, compression_options=None):
        super(CoroutineOnMessageHandler, self).initialize(
            close_future, compression_options)
        # Count of on_message coroutines currently suspended in sleep.
        self._active = 0

    @gen.coroutine
    def on_message(self, message):
        if self._active > 0:
            self.write_message('another coroutine is already sleeping')
        self._active += 1
        yield gen.sleep(0.01)
        self._active -= 1
        self.write_message(message)
class RenderMessageHandler(TestWebSocketHandler):
    # Exercises template rendering (render_string) from within a
    # websocket handler; the result is sent back as a message.
    def on_message(self, message):
        self.write_message(self.render_string('message.html', message=message))
class SubprotocolHandler(TestWebSocketHandler):
    """Selects 'goodproto' when offered and reports the chosen subprotocol."""

    def initialize(self, **kwargs):
        super(SubprotocolHandler, self).initialize(**kwargs)
        self.select_subprotocol_called = False

    def select_subprotocol(self, subprotocols):
        # The server must consult select_subprotocol exactly once.
        if self.select_subprotocol_called:
            raise Exception("select_subprotocol called twice")
        self.select_subprotocol_called = True
        return 'goodproto' if 'goodproto' in subprotocols else None

    def open(self):
        if not self.select_subprotocol_called:
            raise Exception("select_subprotocol not called")
        self.write_message("subprotocol=%s" % self.selected_subprotocol)
class OpenCoroutineHandler(TestWebSocketHandler):
    # Verifies that messages arriving while a coroutine open() is still
    # running are held until open() completes.
    def initialize(self, test, **kwargs):
        super(OpenCoroutineHandler, self).initialize(**kwargs)
        self.test = test
        self.open_finished = False

    @gen.coroutine
    def open(self):
        # Wait until the client has sent its first message, then yield
        # once more; if buffering were broken, on_message would run now.
        yield self.test.message_sent.wait()
        yield gen.sleep(0.010)
        self.open_finished = True

    def on_message(self, message):
        if not self.open_finished:
            raise Exception('on_message called before open finished')
        self.write_message('ok')
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
def ws_connect(self, path, **kwargs):
ws = yield websocket_connect(
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
**kwargs)
raise gen.Return(ws)
@gen.coroutine
@ -114,19 +207,55 @@ class WebSocketTest(WebSocketBaseTestCase):
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
('/header', HeaderHandler, dict(close_future=self.close_future)),
('/header_echo', HeaderEchoHandler,
dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
('/async_prepare', AsyncPrepareHandler,
dict(close_future=self.close_future)),
])
('/path_args/(.*)', PathArgsHandler,
dict(close_future=self.close_future)),
('/coroutine', CoroutineOnMessageHandler,
dict(close_future=self.close_future)),
('/render', RenderMessageHandler,
dict(close_future=self.close_future)),
('/subprotocol', SubprotocolHandler,
dict(close_future=self.close_future)),
('/open_coroutine', OpenCoroutineHandler,
dict(close_future=self.close_future, test=self)),
], template_loader=DictLoader({
'message.html': '<b>{{ message }}</b>',
}))
def get_http_client(self):
# These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient()
def tearDown(self):
super(WebSocketTest, self).tearDown()
RequestHandler._template_loaders.clear()
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch('/echo')
self.assertEqual(response.code, 400)
def test_missing_websocket_key(self):
response = self.fetch('/echo',
headers={'Connection': 'Upgrade',
'Upgrade': 'WebSocket',
'Sec-WebSocket-Version': '13'})
self.assertEqual(response.code, 400)
def test_bad_websocket_version(self):
response = self.fetch('/echo',
headers={'Connection': 'Upgrade',
'Upgrade': 'WebSocket',
'Sec-WebSocket-Version': '12'})
self.assertEqual(response.code, 426)
@gen_test
def test_websocket_gen(self):
ws = yield self.ws_connect('/echo')
@ -138,7 +267,7 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_websocket_callbacks(self):
websocket_connect(
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
ws.read_message(self.stop)
@ -159,9 +288,17 @@ class WebSocketTest(WebSocketBaseTestCase):
@gen_test
def test_unicode_message(self):
ws = yield self.ws_connect('/echo')
ws.write_message(u('hello \u00e9'))
ws.write_message(u'hello \u00e9')
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
self.assertEqual(response, u'hello \u00e9')
yield self.close(ws)
@gen_test
def test_render_message(self):
ws = yield self.ws_connect('/render')
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, '<b>hello</b>')
yield self.close(ws)
@gen_test
@ -192,7 +329,6 @@ class WebSocketTest(WebSocketBaseTestCase):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://127.0.0.1:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
@gen_test
@ -215,6 +351,18 @@ class WebSocketTest(WebSocketBaseTestCase):
self.assertEqual(response, 'hello')
yield self.close(ws)
@gen_test
def test_websocket_header_echo(self):
# Ensure that headers can be returned in the response.
# Specifically, that arbitrary headers passed through websocket_connect
# can be returned.
ws = yield websocket_connect(
HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
headers={'X-Test-Hello': 'hello'}))
self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
yield self.close(ws)
@gen_test
def test_server_close_reason(self):
ws = yield self.ws_connect('/close_reason')
@ -238,6 +386,14 @@ class WebSocketTest(WebSocketBaseTestCase):
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_write_after_close(self):
    # Draining read_message() to None confirms the server closed the
    # connection; writing afterwards must raise WebSocketClosedError.
    ws = yield self.ws_connect('/close_reason')
    msg = yield ws.read_message()
    self.assertIs(msg, None)
    with self.assertRaises(WebSocketClosedError):
        ws.write_message('hello')
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
@ -247,6 +403,23 @@ class WebSocketTest(WebSocketBaseTestCase):
res = yield ws.read_message()
self.assertEqual(res, 'hello')
@gen_test
def test_path_args(self):
ws = yield self.ws_connect('/path_args/hello')
res = yield ws.read_message()
self.assertEqual(res, 'hello')
@gen_test
def test_coroutine(self):
ws = yield self.ws_connect('/coroutine')
# Send both messages immediately, coroutine must process one at a time.
yield ws.write_message('hello1')
yield ws.write_message('hello2')
res = yield ws.read_message()
self.assertEqual(res, 'hello1')
res = yield ws.read_message()
self.assertEqual(res, 'hello2')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
@ -254,8 +427,7 @@ class WebSocketTest(WebSocketBaseTestCase):
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@ -268,8 +440,7 @@ class WebSocketTest(WebSocketBaseTestCase):
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@ -283,8 +454,7 @@ class WebSocketTest(WebSocketBaseTestCase):
headers = {'Origin': '127.0.0.1:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@gen_test
@ -297,8 +467,7 @@ class WebSocketTest(WebSocketBaseTestCase):
headers = {'Origin': 'http://somewhereelse.com'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@ -312,21 +481,94 @@ class WebSocketTest(WebSocketBaseTestCase):
headers = {'Origin': 'http://subtenant.localhost'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_subprotocols(self):
ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
self.assertEqual(ws.selected_subprotocol, 'goodproto')
res = yield ws.read_message()
self.assertEqual(res, 'subprotocol=goodproto')
yield self.close(ws)
@gen_test
def test_subprotocols_not_offered(self):
ws = yield self.ws_connect('/subprotocol')
self.assertIs(ws.selected_subprotocol, None)
res = yield ws.read_message()
self.assertEqual(res, 'subprotocol=None')
yield self.close(ws)
@gen_test
def test_open_coroutine(self):
self.message_sent = Event()
ws = yield self.ws_connect('/open_coroutine')
yield ws.write_message('hello')
self.message_sent.set()
res = yield ws.read_message()
self.assertEqual(res, 'ok')
yield self.close(ws)
# Native (async def) counterpart of CoroutineOnMessageHandler.  The class
# body lives inside a string compiled via exec_test so this module still
# parses on Python < 3.5, where 'async'/'await' are syntax errors.
if sys.version_info >= (3, 5):
    NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
    def initialize(self, close_future, compression_options=None):
        super().initialize(close_future, compression_options)
        self.sleeping = 0

    async def on_message(self, message):
        if self.sleeping > 0:
            self.write_message('another coroutine is already sleeping')
        self.sleeping += 1
        await gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)""")['NativeCoroutineOnMessageHandler']
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
    """Same serialization guarantee as /coroutine, for ``async def`` handlers."""

    def get_app(self):
        self.close_future = Future()
        return Application([
            ('/native', NativeCoroutineOnMessageHandler,
             dict(close_future=self.close_future))])

    @skipBefore35
    @gen_test
    def test_native_coroutine(self):
        ws = yield self.ws_connect('/native')
        # Queue both messages at once; the handler coroutine must still
        # process them strictly one at a time and echo them in order.
        yield ws.write_message('hello1')
        yield ws.write_message('hello2')
        for expected in ('hello1', 'hello2'):
            res = yield ws.read_message()
            self.assertEqual(res, expected)
class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
def get_app(self):
self.close_future = Future()
class LimitedHandler(TestWebSocketHandler):
@property
def max_message_size(self):
return 1024
def on_message(self, message):
self.write_message(str(len(message)))
return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
('/limited', LimitedHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])
def get_server_compression_options(self):
@ -352,6 +594,22 @@ class CompressionTestMixin(object):
ws.protocol._wire_bytes_out)
yield self.close(ws)
@gen_test
def test_size_limit(self):
    # LimitedHandler advertises max_message_size = 1024.
    ws = yield self.ws_connect(
        '/limited',
        compression_options=self.get_client_compression_options())
    # Small messages pass through.
    ws.write_message('a' * 128)
    response = yield ws.read_message()
    self.assertEqual(response, '128')
    # This message is too big after decompression, but it compresses
    # down to a size that will pass the initial checks.
    ws.write_message('a' * 2048)
    response = yield ws.read_message()
    # None means the server closed the connection on us.
    self.assertIsNone(response)
    yield self.close(ws)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
@ -417,3 +675,101 @@ class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return speedups.websocket_mask(mask, data)
class ServerPeriodicPingTest(WebSocketBaseTestCase):
    """Verify that the server pings on ``websocket_ping_interval``."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_pong(self, data):
                self.write_message("got pong")

        self.close_future = Future()
        return Application([
            ('/', PingHandler, dict(close_future=self.close_future)),
        ], websocket_ping_interval=0.01)

    @gen_test
    def test_server_ping(self):
        ws = yield self.ws_connect('/')
        # Every server ping provokes a client pong, which the handler
        # reports back as a message; expect several in a row.
        for _ in range(3):
            msg = yield ws.read_message()
            self.assertEqual(msg, "got pong")
        yield self.close(ws)
    # TODO: test that the connection gets closed if ping responses stop.
class ClientPeriodicPingTest(WebSocketBaseTestCase):
    """Verify that the client pings on its ``ping_interval`` option."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message("got ping")

        self.close_future = Future()
        return Application([
            ('/', PingHandler, dict(close_future=self.close_future)),
        ])

    @gen_test
    def test_client_ping(self):
        ws = yield self.ws_connect('/', ping_interval=0.01)
        # Every client ping is reported back by the server handler;
        # expect several in a row.
        for _ in range(3):
            msg = yield ws.read_message()
            self.assertEqual(msg, "got ping")
        yield self.close(ws)
    # TODO: test that the connection gets closed if ping responses stop.
class ManualPingTest(WebSocketBaseTestCase):
    """Tests for explicitly sending pings via the client connection."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                # Echo the ping payload back as a regular message.
                self.write_message(data, binary=isinstance(data, bytes))

        self.close_future = Future()
        return Application([
            ('/', PingHandler, dict(close_future=self.close_future)),
        ])

    @gen_test
    def test_manual_ping(self):
        ws = yield self.ws_connect('/')
        # An oversized ping payload (>125 bytes) is rejected client-side.
        self.assertRaises(ValueError, ws.ping, 'a' * 126)
        # on_ping always sees bytes, whether we send str or bytes.
        for payload, expected in [('hello', b'hello'),
                                  (b'binary hello', b'binary hello')]:
            ws.ping(payload)
            echoed = yield ws.read_message()
            self.assertEqual(echoed, expected)
        yield self.close(ws)
class MaxMessageSizeTest(WebSocketBaseTestCase):
    """Verify enforcement of the ``websocket_max_message_size`` setting."""

    def get_app(self):
        self.close_future = Future()
        return Application([
            ('/', EchoHandler, dict(close_future=self.close_future)),
        ], websocket_max_message_size=1024)

    @gen_test
    def test_large_message(self):
        ws = yield self.ws_connect('/')

        # A message exactly at the limit is echoed normally.
        msg = 'a' * 1024
        ws.write_message(msg)
        echoed = yield ws.read_message()
        self.assertEqual(echoed, msg)

        # One byte over the limit and the server drops the connection
        # with close code 1009.
        ws.write_message(msg + 'b')
        resp = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(resp, None)
        self.assertEqual(ws.close_code, 1009)
        self.assertEqual(ws.close_reason, "message too big")
# TODO: Needs tests of messages split over multiple
# continuation frames.

View file

@ -0,0 +1,25 @@
from __future__ import absolute_import, division, print_function
import functools
import os
import socket
import unittest
from tornado.platform.auto import set_close_exec
skipIfNonWindows = unittest.skipIf(os.name != 'nt', 'non-windows platform')
@skipIfNonWindows
class WindowsTest(unittest.TestCase):
    """Windows-specific behavior of ``set_close_exec``."""

    def test_set_close_exec(self):
        # set_close_exec works with sockets...
        sock = socket.socket()
        self.addCleanup(sock.close)
        set_close_exec(sock.fileno())

        # ...but it doesn't work with pipes.
        read_fd, write_fd = os.pipe()
        self.addCleanup(functools.partial(os.close, read_fd))
        self.addCleanup(functools.partial(os.close, write_fd))
        with self.assertRaises(WindowsError):
            set_close_exec(read_fd)

View file

@ -1,13 +1,16 @@
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from wsgiref.validate import validator
from tornado.escape import json_decode
from tornado.test.httpserver_test import TypeCheckHandler
from tornado.test.util import ignore_deprecation
from tornado.testing import AsyncHTTPTestCase
from tornado.util import u
from tornado.web import RequestHandler, Application
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
from tornado.test import httpserver_test
from tornado.test import web_test
class WSGIContainerTest(AsyncHTTPTestCase):
def wsgi_app(self, environ, start_response):
@ -24,7 +27,7 @@ class WSGIContainerTest(AsyncHTTPTestCase):
self.assertEqual(response.body, b"Hello world!")
class WSGIApplicationTest(AsyncHTTPTestCase):
class WSGIAdapterTest(AsyncHTTPTestCase):
def get_app(self):
class HelloHandler(RequestHandler):
def get(self):
@ -38,11 +41,13 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
# another thread instead of using our own WSGIContainer, but this
# fits better in our async testing framework and the wsgiref
# validator should keep us honest
return WSGIContainer(validator(WSGIApplication([
("/", HelloHandler),
("/path/(.*)", PathQuotingHandler),
("/typecheck", TypeCheckHandler),
])))
with ignore_deprecation():
return WSGIContainer(validator(WSGIAdapter(
Application([
("/", HelloHandler),
("/path/(.*)", PathQuotingHandler),
("/typecheck", TypeCheckHandler),
]))))
def test_simple(self):
response = self.fetch("/")
@ -50,7 +55,7 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
def test_path_quoting(self):
response = self.fetch("/path/foo%20bar%C3%A9")
self.assertEqual(response.body, u("foo bar\u00e9").encode("utf-8"))
self.assertEqual(response.body, u"foo bar\u00e9".encode("utf-8"))
def test_types(self):
headers = {"Cookie": "foo=bar"}
@ -62,39 +67,52 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
data = json_decode(response.body)
self.assertEqual(data, {})
# This is kind of hacky, but run some of the HTTPServer tests through
# WSGIContainer and WSGIApplication to make sure everything survives
# repeated disassembly and reassembly.
from tornado.test import httpserver_test
from tornado.test import web_test
# This is kind of hacky, but run some of the HTTPServer and web tests
# through WSGIContainer and WSGIApplication to make sure everything
# survives repeated disassembly and reassembly.
class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
def get_app(self):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
with ignore_deprecation():
return WSGIContainer(validator(WSGIAdapter(Application(self.get_handlers()))))
def wrap_web_tests_application():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIApplicationWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
def class_factory():
class WSGIApplicationWrappedTest(cls): # type: ignore
def setUp(self):
self.warning_catcher = ignore_deprecation()
self.warning_catcher.__enter__()
super(WSGIApplicationWrappedTest, self).setUp()
def tearDown(self):
super(WSGIApplicationWrappedTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIApplication_" + cls.__name__] = class_factory()
return result
globals().update(wrap_web_tests_application())
def wrap_web_tests_adapter():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIAdapterWrappedTest(cls):
class WSGIAdapterWrappedTest(cls): # type: ignore
def get_app(self):
self.app = Application(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(WSGIAdapter(self.app)))
with ignore_deprecation():
return WSGIContainer(validator(WSGIAdapter(self.app)))
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
return result
globals().update(wrap_web_tests_adapter())