update shared deps
This commit is contained in:
parent
6806bebb7c
commit
642ba49f68
275 changed files with 31987 additions and 19235 deletions
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
This only works in python 2.7+.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.test.runtests import all, main
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
|
||||
|
||||
|
|
@ -21,7 +23,7 @@ try:
|
|||
except ImportError:
|
||||
asyncio = None
|
||||
else:
|
||||
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future
|
||||
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
|
||||
# This is used in dynamically-evaluated code, so silence pyflakes.
|
||||
to_asyncio_future
|
||||
|
||||
|
|
@ -30,7 +32,6 @@ else:
|
|||
class AsyncIOLoopTest(AsyncTestCase):
|
||||
def get_new_ioloop(self):
|
||||
io_loop = AsyncIOLoop()
|
||||
asyncio.set_event_loop(io_loop.asyncio_loop)
|
||||
return io_loop
|
||||
|
||||
def test_asyncio_callback(self):
|
||||
|
|
@ -41,8 +42,15 @@ class AsyncIOLoopTest(AsyncTestCase):
|
|||
@gen_test
|
||||
def test_asyncio_future(self):
|
||||
# Test that we can yield an asyncio future from a tornado coroutine.
|
||||
# Without 'yield from', we must wrap coroutines in asyncio.async.
|
||||
x = yield asyncio.async(
|
||||
# Without 'yield from', we must wrap coroutines in ensure_future,
|
||||
# which was introduced during Python 3.4, deprecating the prior "async".
|
||||
if hasattr(asyncio, 'ensure_future'):
|
||||
ensure_future = asyncio.ensure_future
|
||||
else:
|
||||
# async is a reserved word in Python 3.7
|
||||
ensure_future = getattr(asyncio, 'async')
|
||||
|
||||
x = yield ensure_future(
|
||||
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
|
||||
self.assertEqual(x, 42)
|
||||
|
||||
|
|
@ -69,7 +77,7 @@ class AsyncIOLoopTest(AsyncTestCase):
|
|||
# as demonstrated by other tests in the package.
|
||||
@gen.coroutine
|
||||
def tornado_coroutine():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
native_coroutine_without_adapter = exec_test(globals(), locals(), """
|
||||
async def native_coroutine_without_adapter():
|
||||
|
|
@ -99,10 +107,11 @@ class AsyncIOLoopTest(AsyncTestCase):
|
|||
42)
|
||||
|
||||
# Asyncio only supports coroutines that yield asyncio-compatible
|
||||
# Futures.
|
||||
with self.assertRaises(RuntimeError):
|
||||
# Futures (which our Future is since 5.0).
|
||||
self.assertEqual(
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_without_adapter())
|
||||
native_coroutine_without_adapter()),
|
||||
42)
|
||||
self.assertEqual(
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_with_adapter()),
|
||||
|
|
@ -111,3 +120,87 @@ class AsyncIOLoopTest(AsyncTestCase):
|
|||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_with_adapter2()),
|
||||
42)
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class LeakTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Trigger a cleanup of the mapping so we start with a clean slate.
|
||||
AsyncIOLoop().close()
|
||||
# If we don't clean up after ourselves other tests may fail on
|
||||
# py34.
|
||||
self.orig_policy = asyncio.get_event_loop_policy()
|
||||
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||
|
||||
def tearDown(self):
|
||||
asyncio.get_event_loop().close()
|
||||
asyncio.set_event_loop_policy(self.orig_policy)
|
||||
|
||||
def test_ioloop_close_leak(self):
|
||||
orig_count = len(IOLoop._ioloop_for_asyncio)
|
||||
for i in range(10):
|
||||
# Create and close an AsyncIOLoop using Tornado interfaces.
|
||||
loop = AsyncIOLoop()
|
||||
loop.close()
|
||||
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
|
||||
self.assertEqual(new_count, 0)
|
||||
|
||||
def test_asyncio_close_leak(self):
|
||||
orig_count = len(IOLoop._ioloop_for_asyncio)
|
||||
for i in range(10):
|
||||
# Create and close an AsyncIOMainLoop using asyncio interfaces.
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.call_soon(IOLoop.current)
|
||||
loop.call_soon(loop.stop)
|
||||
loop.run_forever()
|
||||
loop.close()
|
||||
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
|
||||
# Because the cleanup is run on new loop creation, we have one
|
||||
# dangling entry in the map (but only one).
|
||||
self.assertEqual(new_count, 1)
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.orig_policy = asyncio.get_event_loop_policy()
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
|
||||
def tearDown(self):
|
||||
asyncio.set_event_loop_policy(self.orig_policy)
|
||||
self.executor.shutdown()
|
||||
|
||||
def get_event_loop_on_thread(self):
|
||||
def get_and_close_event_loop():
|
||||
"""Get the event loop. Close it if one is returned.
|
||||
|
||||
Returns the (closed) event loop. This is a silly thing
|
||||
to do and leaves the thread in a broken state, but it's
|
||||
enough for this test. Closing the loop avoids resource
|
||||
leak warnings.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.close()
|
||||
return loop
|
||||
future = self.executor.submit(get_and_close_event_loop)
|
||||
return future.result()
|
||||
|
||||
def run_policy_test(self, accessor, expected_type):
|
||||
# With the default policy, non-main threads don't get an event
|
||||
# loop.
|
||||
self.assertRaises((RuntimeError, AssertionError),
|
||||
self.executor.submit(accessor).result)
|
||||
# Set the policy and we can get a loop.
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
self.assertIsInstance(
|
||||
self.executor.submit(accessor).result(),
|
||||
expected_type)
|
||||
# Clean up to silence leak warnings. Always use asyncio since
|
||||
# IOLoop doesn't (currently) close the underlying loop.
|
||||
self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
|
||||
|
||||
def test_asyncio_accessor(self):
|
||||
self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)
|
||||
|
||||
def test_tornado_accessor(self):
|
||||
self.run_policy_test(IOLoop.current, IOLoop)
|
||||
|
|
|
|||
|
|
@ -4,37 +4,69 @@
|
|||
# python 3)
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError, GoogleOAuth2Mixin, FacebookGraphMixin
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from tornado.auth import (
|
||||
AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin,
|
||||
GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
|
||||
)
|
||||
from tornado.concurrent import Future
|
||||
from tornado.escape import json_decode
|
||||
from tornado import gen
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.log import gen_log
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.testing import AsyncHTTPTestCase, ExpectLog
|
||||
from tornado.util import u
|
||||
from tornado.test.util import ignore_deprecation
|
||||
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
|
||||
def initialize(self, test):
|
||||
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
|
||||
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument('openid.mode', None):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
return
|
||||
res = self.authenticate_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
|
||||
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
|
||||
def initialize(self, test):
|
||||
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
|
||||
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('openid.mode', None):
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
res = self.authenticate_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
|
||||
class OpenIdServerAuthenticateHandler(RequestHandler):
|
||||
def post(self):
|
||||
|
|
@ -43,7 +75,7 @@ class OpenIdServerAuthenticateHandler(RequestHandler):
|
|||
self.write('is_valid:true')
|
||||
|
||||
|
||||
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
||||
class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin):
|
||||
def initialize(self, test, version):
|
||||
self._OAUTH_VERSION = version
|
||||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
|
|
@ -53,14 +85,17 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
|||
def _oauth_consumer_token(self):
|
||||
return dict(key='asdf', secret='qwer')
|
||||
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
return
|
||||
res = self.authorize_redirect(http_client=self.settings['http_client'])
|
||||
assert isinstance(res, Future)
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
return
|
||||
res = self.authorize_redirect(http_client=self.settings['http_client'])
|
||||
assert isinstance(res, Future)
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
|
|
@ -75,6 +110,35 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
|||
callback(dict(email='foo@example.com'))
|
||||
|
||||
|
||||
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
||||
def initialize(self, test, version):
|
||||
self._OAUTH_VERSION = version
|
||||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
return dict(key='asdf', secret='qwer')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
yield self.authorize_redirect(http_client=self.settings['http_client'])
|
||||
|
||||
@gen.coroutine
|
||||
def _oauth_get_user_future(self, access_token):
|
||||
if self.get_argument('fail_in_get_user', None):
|
||||
raise Exception("failing in get_user")
|
||||
if access_token != dict(key='uiop', secret='5678'):
|
||||
raise Exception("incorrect access token %r" % access_token)
|
||||
return dict(email='foo@example.com')
|
||||
|
||||
|
||||
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
|
||||
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
|
||||
@gen.coroutine
|
||||
|
|
@ -150,7 +214,7 @@ class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
|
|||
|
||||
class FacebookServerAccessTokenHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('access_token=asdf')
|
||||
self.write(dict(access_token="asdf", expires_in=3600))
|
||||
|
||||
|
||||
class FacebookServerMeHandler(RequestHandler):
|
||||
|
|
@ -163,19 +227,21 @@ class TwitterClientHandler(RequestHandler, TwitterMixin):
|
|||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
|
||||
self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate')
|
||||
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
|
||||
|
||||
def get_auth_http_client(self):
|
||||
return self.settings['http_client']
|
||||
|
||||
|
||||
class TwitterClientLoginHandler(TwitterClientHandler):
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
self.get_authenticated_user(self.on_user)
|
||||
return
|
||||
self.authorize_redirect()
|
||||
class TwitterClientLoginHandlerLegacy(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
self.get_authenticated_user(self.on_user)
|
||||
return
|
||||
self.authorize_redirect()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
|
|
@ -183,17 +249,44 @@ class TwitterClientLoginHandler(TwitterClientHandler):
|
|||
self.finish(user)
|
||||
|
||||
|
||||
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
class TwitterClientLoginHandler(TwitterClientHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
else:
|
||||
# Old style: with @gen.engine we can ignore the Future from
|
||||
# authorize_redirect.
|
||||
self.authorize_redirect()
|
||||
return
|
||||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientAuthenticateHandler(TwitterClientHandler):
|
||||
# Like TwitterClientLoginHandler, but uses authenticate_redirect
|
||||
# instead of authorize_redirect.
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
yield self.authenticate_redirect()
|
||||
|
||||
|
||||
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
self.finish(user)
|
||||
else:
|
||||
# Old style: with @gen.engine we can ignore the Future from
|
||||
# authorize_redirect.
|
||||
self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
|
||||
|
|
@ -208,36 +301,39 @@ class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
|
|||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientShowUserHandlerLegacy(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
# TODO: would be nice to go through the login flow instead of
|
||||
# cheating with a hard-coded access token.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
response = yield gen.Task(self.twitter_request,
|
||||
'/users/show/%s' % self.get_argument('name'),
|
||||
access_token=dict(key='hjkl', secret='vbnm'))
|
||||
if response is None:
|
||||
self.set_status(500)
|
||||
self.finish('error from twitter request')
|
||||
else:
|
||||
self.finish(response)
|
||||
|
||||
|
||||
class TwitterClientShowUserHandler(TwitterClientHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
# TODO: would be nice to go through the login flow instead of
|
||||
# cheating with a hard-coded access token.
|
||||
response = yield gen.Task(self.twitter_request,
|
||||
'/users/show/%s' % self.get_argument('name'),
|
||||
access_token=dict(key='hjkl', secret='vbnm'))
|
||||
if response is None:
|
||||
self.set_status(500)
|
||||
self.finish('error from twitter request')
|
||||
else:
|
||||
self.finish(response)
|
||||
|
||||
|
||||
class TwitterClientShowUserFutureHandler(TwitterClientHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
try:
|
||||
response = yield self.twitter_request(
|
||||
'/users/show/%s' % self.get_argument('name'),
|
||||
access_token=dict(key='hjkl', secret='vbnm'))
|
||||
except AuthError as e:
|
||||
except AuthError:
|
||||
self.set_status(500)
|
||||
self.finish(str(e))
|
||||
return
|
||||
assert response is not None
|
||||
self.finish(response)
|
||||
self.finish('error from twitter request')
|
||||
else:
|
||||
self.finish(response)
|
||||
|
||||
|
||||
class TwitterServerAccessTokenHandler(RequestHandler):
|
||||
|
|
@ -276,12 +372,17 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
return Application(
|
||||
[
|
||||
# test endpoints
|
||||
('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)),
|
||||
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
|
||||
('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy,
|
||||
dict(test=self, version='1.0')),
|
||||
('/oauth10/client/login', OAuth1ClientLoginHandler,
|
||||
dict(test=self, version='1.0')),
|
||||
('/oauth10/client/request_params',
|
||||
OAuth1ClientRequestParametersHandler,
|
||||
dict(version='1.0')),
|
||||
('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/login', OAuth1ClientLoginHandler,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/login_coroutine',
|
||||
|
|
@ -294,11 +395,17 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
|
||||
('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
|
||||
|
||||
('/legacy/twitter/client/login', TwitterClientLoginHandlerLegacy, dict(test=self)),
|
||||
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_engine', TwitterClientLoginGenEngineHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
|
||||
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
|
||||
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
|
||||
('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_engine',
|
||||
TwitterClientLoginGenEngineHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_coroutine',
|
||||
TwitterClientLoginGenCoroutineHandler, dict(test=self)),
|
||||
('/legacy/twitter/client/show_user',
|
||||
TwitterClientShowUserHandlerLegacy, dict(test=self)),
|
||||
('/twitter/client/show_user',
|
||||
TwitterClientShowUserHandler, dict(test=self)),
|
||||
|
||||
# simulated servers
|
||||
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
|
||||
|
|
@ -309,7 +416,8 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
('/facebook/server/me', FacebookServerMeHandler),
|
||||
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
|
||||
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
|
||||
(r'/twitter/api/account/verify_credentials\.json', TwitterServerVerifyCredentialsHandler),
|
||||
(r'/twitter/api/account/verify_credentials\.json',
|
||||
TwitterServerVerifyCredentialsHandler),
|
||||
],
|
||||
http_client=self.http_client,
|
||||
twitter_consumer_key='test_twitter_consumer_key',
|
||||
|
|
@ -317,6 +425,21 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
facebook_api_key='test_facebook_api_key',
|
||||
facebook_secret='test_facebook_secret')
|
||||
|
||||
def test_openid_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/openid/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(
|
||||
'/openid/server/authenticate?' in response.headers['Location'])
|
||||
|
||||
def test_openid_get_user_legacy(self):
|
||||
response = self.fetch('/legacy/openid/client/login?openid.mode=blah'
|
||||
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
|
||||
'&openid.ax.type.email=http://axschema.org/contact/email'
|
||||
'&openid.ax.value.email=foo@example.com')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed["email"], "foo@example.com")
|
||||
|
||||
def test_openid_redirect(self):
|
||||
response = self.fetch('/openid/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
|
|
@ -324,11 +447,24 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
'/openid/server/authenticate?' in response.headers['Location'])
|
||||
|
||||
def test_openid_get_user(self):
|
||||
response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
|
||||
response = self.fetch('/openid/client/login?openid.mode=blah'
|
||||
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
|
||||
'&openid.ax.type.email=http://axschema.org/contact/email'
|
||||
'&openid.ax.value.email=foo@example.com')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed["email"], "foo@example.com")
|
||||
|
||||
def test_oauth10_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10_redirect(self):
|
||||
response = self.fetch('/oauth10/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
|
|
@ -339,6 +475,16 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10_get_user_legacy(self):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(
|
||||
'/legacy/oauth10/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10_get_user(self):
|
||||
response = self.fetch(
|
||||
'/oauth10/client/login?oauth_token=zxcv',
|
||||
|
|
@ -357,6 +503,26 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
self.assertTrue('oauth_nonce' in parsed)
|
||||
self.assertTrue('oauth_signature' in parsed)
|
||||
|
||||
def test_oauth10a_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10a_get_user_legacy(self):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(
|
||||
'/legacy/oauth10a/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10a_redirect(self):
|
||||
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
|
|
@ -367,6 +533,14 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_oauth10a_redirect_error(self):
|
||||
with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
|
||||
get.side_effect = Exception("boom")
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 500)
|
||||
|
||||
def test_oauth10a_get_user(self):
|
||||
response = self.fetch(
|
||||
'/oauth10a/client/login?oauth_token=zxcv',
|
||||
|
|
@ -402,6 +576,9 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
|
||||
response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
|
||||
self.assertEqual(response.code, 200)
|
||||
user = json_decode(response.body)
|
||||
self.assertEqual(user['access_token'], 'asdf')
|
||||
self.assertEqual(user['session_expires'], '3600')
|
||||
|
||||
def base_twitter_redirect(self, url):
|
||||
# Same as test_oauth10a_redirect
|
||||
|
|
@ -414,6 +591,9 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_twitter_redirect_legacy(self):
|
||||
self.base_twitter_redirect('/legacy/twitter/client/login')
|
||||
|
||||
def test_twitter_redirect(self):
|
||||
self.base_twitter_redirect('/twitter/client/login')
|
||||
|
||||
|
|
@ -423,6 +603,16 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
def test_twitter_redirect_gen_coroutine(self):
|
||||
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
|
||||
|
||||
def test_twitter_authenticate_redirect(self):
|
||||
response = self.fetch('/twitter/client/authenticate', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location'])
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_twitter_get_user(self):
|
||||
response = self.fetch(
|
||||
'/twitter/client/login?oauth_token=zxcv',
|
||||
|
|
@ -430,12 +620,24 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed,
|
||||
{u('access_token'): {u('key'): u('hjkl'),
|
||||
u('screen_name'): u('foo'),
|
||||
u('secret'): u('vbnm')},
|
||||
u('name'): u('Foo'),
|
||||
u('screen_name'): u('foo'),
|
||||
u('username'): u('foo')})
|
||||
{u'access_token': {u'key': u'hjkl',
|
||||
u'screen_name': u'foo',
|
||||
u'secret': u'vbnm'},
|
||||
u'name': u'Foo',
|
||||
u'screen_name': u'foo',
|
||||
u'username': u'foo'})
|
||||
|
||||
def test_twitter_show_user_legacy(self):
|
||||
response = self.fetch('/legacy/twitter/client/show_user?name=somebody')
|
||||
response.rethrow()
|
||||
self.assertEqual(json_decode(response.body),
|
||||
{'name': 'Somebody', 'screen_name': 'somebody'})
|
||||
|
||||
def test_twitter_show_user_error_legacy(self):
|
||||
with ExpectLog(gen_log, 'Error response HTTP 500'):
|
||||
response = self.fetch('/legacy/twitter/client/show_user?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertEqual(response.body, b'error from twitter request')
|
||||
|
||||
def test_twitter_show_user(self):
|
||||
response = self.fetch('/twitter/client/show_user?name=somebody')
|
||||
|
|
@ -444,22 +646,10 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
{'name': 'Somebody', 'screen_name': 'somebody'})
|
||||
|
||||
def test_twitter_show_user_error(self):
|
||||
with ExpectLog(gen_log, 'Error response HTTP 500'):
|
||||
response = self.fetch('/twitter/client/show_user?name=error')
|
||||
response = self.fetch('/twitter/client/show_user?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertEqual(response.body, b'error from twitter request')
|
||||
|
||||
def test_twitter_show_user_future(self):
|
||||
response = self.fetch('/twitter/client/show_user_future?name=somebody')
|
||||
response.rethrow()
|
||||
self.assertEqual(json_decode(response.body),
|
||||
{'name': 'Somebody', 'screen_name': 'somebody'})
|
||||
|
||||
def test_twitter_show_user_future_error(self):
|
||||
response = self.fetch('/twitter/client/show_user_future?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertIn(b'Error response HTTP 500', response.body)
|
||||
|
||||
|
||||
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
|
||||
def initialize(self, test):
|
||||
|
|
@ -539,7 +729,7 @@ class GoogleOAuth2Test(AsyncHTTPTestCase):
|
|||
def test_google_login(self):
|
||||
response = self.fetch('/client/login')
|
||||
self.assertDictEqual({
|
||||
u('name'): u('Foo'),
|
||||
u('email'): u('foo@example.com'),
|
||||
u('access_token'): u('fake-access-token'),
|
||||
u'name': u'Foo',
|
||||
u'email': u'foo@example.com',
|
||||
u'access_token': u'fake-access-token',
|
||||
}, json_decode(response.body))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
from tempfile import mkdtemp
|
||||
import time
|
||||
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
||||
class AutoreloadTest(unittest.TestCase):
|
||||
|
||||
def test_reload_module(self):
|
||||
main = """\
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tornado import autoreload
|
||||
|
||||
# This import will fail if path is not set up correctly
|
||||
import testapp
|
||||
|
||||
print('Starting')
|
||||
if 'TESTAPP_STARTED' not in os.environ:
|
||||
os.environ['TESTAPP_STARTED'] = '1'
|
||||
sys.stdout.flush()
|
||||
autoreload._reload()
|
||||
"""
|
||||
|
||||
# Create temporary test application
|
||||
path = mkdtemp()
|
||||
self.addCleanup(shutil.rmtree, path)
|
||||
os.mkdir(os.path.join(path, 'testapp'))
|
||||
open(os.path.join(path, 'testapp/__init__.py'), 'w').close()
|
||||
with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f:
|
||||
f.write(main)
|
||||
|
||||
# Make sure the tornado module under test is available to the test
|
||||
# application
|
||||
pythonpath = os.getcwd()
|
||||
if 'PYTHONPATH' in os.environ:
|
||||
pythonpath += os.pathsep + os.environ['PYTHONPATH']
|
||||
|
||||
p = Popen(
|
||||
[sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
|
||||
cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath),
|
||||
universal_newlines=True)
|
||||
out = p.communicate()[0]
|
||||
self.assertEqual(out, 'Starting\nStarting\n')
|
||||
|
||||
def test_reload_wrapper_preservation(self):
|
||||
# This test verifies that when `python -m tornado.autoreload`
|
||||
# is used on an application that also has an internal
|
||||
# autoreload, the reload wrapper is preserved on restart.
|
||||
main = """\
|
||||
import os
|
||||
import sys
|
||||
|
||||
# This import will fail if path is not set up correctly
|
||||
import testapp
|
||||
|
||||
if 'tornado.autoreload' not in sys.modules:
|
||||
raise Exception('started without autoreload wrapper')
|
||||
|
||||
import tornado.autoreload
|
||||
|
||||
print('Starting')
|
||||
sys.stdout.flush()
|
||||
if 'TESTAPP_STARTED' not in os.environ:
|
||||
os.environ['TESTAPP_STARTED'] = '1'
|
||||
# Simulate an internal autoreload (one not caused
|
||||
# by the wrapper).
|
||||
tornado.autoreload._reload()
|
||||
else:
|
||||
# Exit directly so autoreload doesn't catch it.
|
||||
os._exit(0)
|
||||
"""
|
||||
|
||||
# Create temporary test application
|
||||
path = mkdtemp()
|
||||
os.mkdir(os.path.join(path, 'testapp'))
|
||||
self.addCleanup(shutil.rmtree, path)
|
||||
init_file = os.path.join(path, 'testapp', '__init__.py')
|
||||
open(init_file, 'w').close()
|
||||
main_file = os.path.join(path, 'testapp', '__main__.py')
|
||||
with open(main_file, 'w') as f:
|
||||
f.write(main)
|
||||
|
||||
# Make sure the tornado module under test is available to the test
|
||||
# application
|
||||
pythonpath = os.getcwd()
|
||||
if 'PYTHONPATH' in os.environ:
|
||||
pythonpath += os.pathsep + os.environ['PYTHONPATH']
|
||||
|
||||
autoreload_proc = Popen(
|
||||
[sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'],
|
||||
stdout=subprocess.PIPE, cwd=path,
|
||||
env=dict(os.environ, PYTHONPATH=pythonpath),
|
||||
universal_newlines=True)
|
||||
|
||||
# This timeout needs to be fairly generous for pypy due to jit
|
||||
# warmup costs.
|
||||
for i in range(40):
|
||||
if autoreload_proc.poll() is not None:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
autoreload_proc.kill()
|
||||
raise Exception("subprocess failed to terminate")
|
||||
|
||||
out = autoreload_proc.communicate()[0]
|
||||
self.assertEqual(out, 'Starting\n' * 2)
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
|
|
@ -13,22 +12,27 @@
|
|||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
|
||||
from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
|
||||
run_on_executor, future_set_result_unless_cancelled)
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import app_log
|
||||
from tornado import stack_context
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -37,33 +41,51 @@ except ImportError:
|
|||
futures = None
|
||||
|
||||
|
||||
class MiscFutureTest(AsyncTestCase):
|
||||
|
||||
def test_future_set_result_unless_cancelled(self):
|
||||
fut = Future()
|
||||
future_set_result_unless_cancelled(fut, 42)
|
||||
self.assertEqual(fut.result(), 42)
|
||||
self.assertFalse(fut.cancelled())
|
||||
|
||||
fut = Future()
|
||||
fut.cancel()
|
||||
is_cancelled = fut.cancelled()
|
||||
future_set_result_unless_cancelled(fut, 42)
|
||||
self.assertEqual(fut.cancelled(), is_cancelled)
|
||||
if not is_cancelled:
|
||||
self.assertEqual(fut.result(), 42)
|
||||
|
||||
|
||||
class ReturnFutureTest(AsyncTestCase):
|
||||
@return_future
|
||||
def sync_future(self, callback):
|
||||
callback(42)
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def sync_future(self, callback):
|
||||
callback(42)
|
||||
|
||||
@return_future
|
||||
def async_future(self, callback):
|
||||
self.io_loop.add_callback(callback, 42)
|
||||
@return_future
|
||||
def async_future(self, callback):
|
||||
self.io_loop.add_callback(callback, 42)
|
||||
|
||||
@return_future
|
||||
def immediate_failure(self, callback):
|
||||
1 / 0
|
||||
@return_future
|
||||
def immediate_failure(self, callback):
|
||||
1 / 0
|
||||
|
||||
@return_future
|
||||
def delayed_failure(self, callback):
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
@return_future
|
||||
def delayed_failure(self, callback):
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
|
||||
@return_future
|
||||
def return_value(self, callback):
|
||||
# Note that the result of both running the callback and returning
|
||||
# a value (or raising an exception) is unspecified; with current
|
||||
# implementations the last event prior to callback resolution wins.
|
||||
return 42
|
||||
@return_future
|
||||
def return_value(self, callback):
|
||||
# Note that the result of both running the callback and returning
|
||||
# a value (or raising an exception) is unspecified; with current
|
||||
# implementations the last event prior to callback resolution wins.
|
||||
return 42
|
||||
|
||||
@return_future
|
||||
def no_result_future(self, callback):
|
||||
callback()
|
||||
@return_future
|
||||
def no_result_future(self, callback):
|
||||
callback()
|
||||
|
||||
def test_immediate_failure(self):
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
|
|
@ -79,7 +101,8 @@ class ReturnFutureTest(AsyncTestCase):
|
|||
self.return_value(callback=self.stop)
|
||||
|
||||
def test_callback_kw(self):
|
||||
future = self.sync_future(callback=self.stop)
|
||||
with ignore_deprecation():
|
||||
future = self.sync_future(callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, 42)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
|
@ -87,7 +110,8 @@ class ReturnFutureTest(AsyncTestCase):
|
|||
def test_callback_positional(self):
|
||||
# When the callback is passed in positionally, future_wrap shouldn't
|
||||
# add another callback in the kwargs.
|
||||
future = self.sync_future(self.stop)
|
||||
with ignore_deprecation():
|
||||
future = self.sync_future(self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, 42)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
|
@ -120,44 +144,68 @@ class ReturnFutureTest(AsyncTestCase):
|
|||
|
||||
def test_delayed_failure(self):
|
||||
future = self.delayed_failure()
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
future2 = self.wait()
|
||||
with ignore_deprecation():
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
future2 = self.wait()
|
||||
self.assertIs(future, future2)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
future.result()
|
||||
|
||||
def test_kw_only_callback(self):
|
||||
@return_future
|
||||
def f(**kwargs):
|
||||
kwargs['callback'](42)
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def f(**kwargs):
|
||||
kwargs['callback'](42)
|
||||
future = f()
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_error_in_callback(self):
|
||||
self.sync_future(callback=lambda future: 1 / 0)
|
||||
with ignore_deprecation():
|
||||
self.sync_future(callback=lambda future: 1 / 0)
|
||||
# The exception gets caught by our StackContext and will be re-raised
|
||||
# when we wait.
|
||||
self.assertRaises(ZeroDivisionError, self.wait)
|
||||
|
||||
def test_no_result_future(self):
|
||||
future = self.no_result_future(self.stop)
|
||||
with ignore_deprecation():
|
||||
future = self.no_result_future(self.stop)
|
||||
result = self.wait()
|
||||
self.assertIs(result, None)
|
||||
# result of this future is undefined, but not an error
|
||||
future.result()
|
||||
|
||||
def test_no_result_future_callback(self):
|
||||
future = self.no_result_future(callback=lambda: self.stop())
|
||||
with ignore_deprecation():
|
||||
future = self.no_result_future(callback=lambda: self.stop())
|
||||
result = self.wait()
|
||||
self.assertIs(result, None)
|
||||
future.result()
|
||||
|
||||
@gen_test
|
||||
def test_future_traceback_legacy(self):
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
@gen.engine
|
||||
def f(callback):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
try:
|
||||
1 / 0
|
||||
except ZeroDivisionError:
|
||||
self.expected_frame = traceback.extract_tb(
|
||||
sys.exc_info()[2], limit=1)[0]
|
||||
raise
|
||||
try:
|
||||
yield f()
|
||||
self.fail("didn't get expected exception")
|
||||
except ZeroDivisionError:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
self.assertIn(self.expected_frame, tb)
|
||||
|
||||
@gen_test
|
||||
def test_future_traceback(self):
|
||||
@return_future
|
||||
@gen.engine
|
||||
def f(callback):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
try:
|
||||
1 / 0
|
||||
except ZeroDivisionError:
|
||||
|
|
@ -171,25 +219,50 @@ class ReturnFutureTest(AsyncTestCase):
|
|||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
self.assertIn(self.expected_frame, tb)
|
||||
|
||||
@gen_test
|
||||
def test_uncaught_exception_log(self):
|
||||
if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
|
||||
# Install an exception handler that mirrors our
|
||||
# non-asyncio logging behavior.
|
||||
def exc_handler(loop, context):
|
||||
app_log.error('%s: %s', context['message'],
|
||||
type(context.get('exception')))
|
||||
self.io_loop.asyncio_loop.set_exception_handler(exc_handler)
|
||||
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
|
||||
g = f()
|
||||
|
||||
with ExpectLog(app_log,
|
||||
"(?s)Future.* exception was never retrieved:"
|
||||
".*ZeroDivisionError"):
|
||||
yield gen.moment
|
||||
yield gen.moment
|
||||
# For some reason, TwistedIOLoop and pypy3 need a third iteration
|
||||
# in order to drain references to the future
|
||||
yield gen.moment
|
||||
del g
|
||||
gc.collect() # for PyPy
|
||||
|
||||
|
||||
# The following series of classes demonstrate and test various styles
|
||||
# of use, with and without generators and futures.
|
||||
|
||||
|
||||
class CapServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
logging.info("handle_stream")
|
||||
self.stream = stream
|
||||
self.stream.read_until(b"\n", self.handle_read)
|
||||
|
||||
def handle_read(self, data):
|
||||
logging.info("handle_read")
|
||||
data = yield stream.read_until(b"\n")
|
||||
data = to_unicode(data)
|
||||
if data == data.upper():
|
||||
self.stream.write(b"error\talready capitalized\n")
|
||||
stream.write(b"error\talready capitalized\n")
|
||||
else:
|
||||
# data already has \n
|
||||
self.stream.write(utf8("ok\t%s" % data.upper()))
|
||||
self.stream.close()
|
||||
stream.write(utf8("ok\t%s" % data.upper()))
|
||||
stream.close()
|
||||
|
||||
|
||||
class CapError(Exception):
|
||||
|
|
@ -197,9 +270,8 @@ class CapError(Exception):
|
|||
|
||||
|
||||
class BaseCapClient(object):
|
||||
def __init__(self, port, io_loop):
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
self.io_loop = io_loop
|
||||
|
||||
def process_response(self, data):
|
||||
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
|
||||
|
|
@ -211,9 +283,9 @@ class BaseCapClient(object):
|
|||
|
||||
class ManualCapClient(BaseCapClient):
|
||||
def capitalize(self, request_data, callback=None):
|
||||
logging.info("capitalize")
|
||||
logging.debug("capitalize")
|
||||
self.request_data = request_data
|
||||
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('127.0.0.1', self.port),
|
||||
callback=self.handle_connect)
|
||||
self.future = Future()
|
||||
|
|
@ -223,12 +295,12 @@ class ManualCapClient(BaseCapClient):
|
|||
return self.future
|
||||
|
||||
def handle_connect(self):
|
||||
logging.info("handle_connect")
|
||||
logging.debug("handle_connect")
|
||||
self.stream.write(utf8(self.request_data + "\n"))
|
||||
self.stream.read_until(b'\n', callback=self.handle_read)
|
||||
|
||||
def handle_read(self, data):
|
||||
logging.info("handle_read")
|
||||
logging.debug("handle_read")
|
||||
self.stream.close()
|
||||
try:
|
||||
self.future.set_result(self.process_response(data))
|
||||
|
|
@ -237,62 +309,64 @@ class ManualCapClient(BaseCapClient):
|
|||
|
||||
|
||||
class DecoratorCapClient(BaseCapClient):
|
||||
@return_future
|
||||
def capitalize(self, request_data, callback):
|
||||
logging.info("capitalize")
|
||||
self.request_data = request_data
|
||||
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
self.stream.connect(('127.0.0.1', self.port),
|
||||
callback=self.handle_connect)
|
||||
self.callback = callback
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def capitalize(self, request_data, callback):
|
||||
logging.debug("capitalize")
|
||||
self.request_data = request_data
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('127.0.0.1', self.port),
|
||||
callback=self.handle_connect)
|
||||
self.callback = callback
|
||||
|
||||
def handle_connect(self):
|
||||
logging.info("handle_connect")
|
||||
logging.debug("handle_connect")
|
||||
self.stream.write(utf8(self.request_data + "\n"))
|
||||
self.stream.read_until(b'\n', callback=self.handle_read)
|
||||
|
||||
def handle_read(self, data):
|
||||
logging.info("handle_read")
|
||||
logging.debug("handle_read")
|
||||
self.stream.close()
|
||||
self.callback(self.process_response(data))
|
||||
|
||||
|
||||
class GeneratorCapClient(BaseCapClient):
|
||||
@return_future
|
||||
@gen.engine
|
||||
def capitalize(self, request_data, callback):
|
||||
logging.info('capitalize')
|
||||
stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
logging.info('connecting')
|
||||
yield gen.Task(stream.connect, ('127.0.0.1', self.port))
|
||||
@gen.coroutine
|
||||
def capitalize(self, request_data):
|
||||
logging.debug('capitalize')
|
||||
stream = IOStream(socket.socket())
|
||||
logging.debug('connecting')
|
||||
yield stream.connect(('127.0.0.1', self.port))
|
||||
stream.write(utf8(request_data + '\n'))
|
||||
logging.info('reading')
|
||||
data = yield gen.Task(stream.read_until, b'\n')
|
||||
logging.info('returning')
|
||||
logging.debug('reading')
|
||||
data = yield stream.read_until(b'\n')
|
||||
logging.debug('returning')
|
||||
stream.close()
|
||||
callback(self.process_response(data))
|
||||
raise gen.Return(self.process_response(data))
|
||||
|
||||
|
||||
class ClientTestMixin(object):
|
||||
def setUp(self):
|
||||
super(ClientTestMixin, self).setUp()
|
||||
self.server = CapServer(io_loop=self.io_loop)
|
||||
super(ClientTestMixin, self).setUp() # type: ignore
|
||||
self.server = CapServer()
|
||||
sock, port = bind_unused_port()
|
||||
self.server.add_sockets([sock])
|
||||
self.client = self.client_class(io_loop=self.io_loop, port=port)
|
||||
self.client = self.client_class(port=port)
|
||||
|
||||
def tearDown(self):
|
||||
self.server.stop()
|
||||
super(ClientTestMixin, self).tearDown()
|
||||
super(ClientTestMixin, self).tearDown() # type: ignore
|
||||
|
||||
def test_callback(self):
|
||||
self.client.capitalize("hello", callback=self.stop)
|
||||
with ignore_deprecation():
|
||||
self.client.capitalize("hello", callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, "HELLO")
|
||||
|
||||
def test_callback_error(self):
|
||||
self.client.capitalize("HELLO", callback=self.stop)
|
||||
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
|
||||
with ignore_deprecation():
|
||||
self.client.capitalize("HELLO", callback=self.stop)
|
||||
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
|
||||
|
||||
def test_future(self):
|
||||
future = self.client.capitalize("hello")
|
||||
|
|
@ -307,33 +381,49 @@ class ClientTestMixin(object):
|
|||
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
|
||||
|
||||
def test_generator(self):
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def f():
|
||||
result = yield self.client.capitalize("hello")
|
||||
self.assertEqual(result, "HELLO")
|
||||
self.stop()
|
||||
f()
|
||||
self.wait()
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_generator_error(self):
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def f():
|
||||
with self.assertRaisesRegexp(CapError, "already capitalized"):
|
||||
yield self.client.capitalize("HELLO")
|
||||
self.stop()
|
||||
f()
|
||||
self.wait()
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
|
||||
class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||||
class ManualClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = ManualCapClient
|
||||
|
||||
def setUp(self):
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
super(ManualClientTest, self).setUp()
|
||||
|
||||
class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||||
def tearDown(self):
|
||||
super(ManualClientTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
|
||||
class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = DecoratorCapClient
|
||||
|
||||
def setUp(self):
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
super(DecoratorClientTest, self).setUp()
|
||||
|
||||
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||||
def tearDown(self):
|
||||
super(DecoratorClientTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
|
||||
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = GeneratorCapClient
|
||||
|
||||
|
||||
|
|
@ -342,74 +432,65 @@ class RunOnExecutorTest(AsyncTestCase):
|
|||
@gen_test
|
||||
def test_no_calling(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_no_args(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor()
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_io_loop(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self._io_loop = io_loop
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(io_loop='_io_loop')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_executor(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
def __init__(self):
|
||||
self.__executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(executor='_Object__executor')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_call_with_both(self):
|
||||
def test_async_await(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self._io_loop = io_loop
|
||||
self.__executor = futures.thread.ThreadPoolExecutor(1)
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(io_loop='_io_loop', executor='_Object__executor')
|
||||
@run_on_executor()
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
o = Object()
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
answer = await o.f()
|
||||
return answer
|
||||
""")
|
||||
result = yield namespace['f']()
|
||||
self.assertEqual(result, 42)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,20 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from hashlib import md5
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.httpclient import HTTPRequest, HTTPClientError
|
||||
from tornado.locks import Event
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.test.util import unittest, ignore_deprecation
|
||||
from tornado.web import Application, RequestHandler
|
||||
|
||||
|
||||
try:
|
||||
import pycurl
|
||||
import pycurl # type: ignore
|
||||
except ImportError:
|
||||
pycurl = None
|
||||
|
||||
|
|
@ -23,21 +25,22 @@ if pycurl is not None:
|
|||
@unittest.skipIf(pycurl is None, "pycurl module not present")
|
||||
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
def get_http_client(self):
|
||||
client = CurlAsyncHTTPClient(io_loop=self.io_loop,
|
||||
defaults=dict(allow_ipv6=False))
|
||||
client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
|
||||
# make sure AsyncHTTPClient magic doesn't give us the wrong class
|
||||
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
|
||||
return client
|
||||
|
||||
|
||||
class DigestAuthHandler(RequestHandler):
|
||||
def initialize(self, username, password):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def get(self):
|
||||
realm = 'test'
|
||||
opaque = 'asdf'
|
||||
# Real implementations would use a random nonce.
|
||||
nonce = "1234"
|
||||
username = 'foo'
|
||||
password = 'bar'
|
||||
|
||||
auth_header = self.request.headers.get('Authorization', None)
|
||||
if auth_header is not None:
|
||||
|
|
@ -52,9 +55,9 @@ class DigestAuthHandler(RequestHandler):
|
|||
assert param_dict['realm'] == realm
|
||||
assert param_dict['opaque'] == opaque
|
||||
assert param_dict['nonce'] == nonce
|
||||
assert param_dict['username'] == username
|
||||
assert param_dict['username'] == self.username
|
||||
assert param_dict['uri'] == self.request.path
|
||||
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
|
||||
h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest()
|
||||
h2 = md5(utf8('%s:%s' % (self.request.method,
|
||||
self.request.path))).hexdigest()
|
||||
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
|
||||
|
|
@ -83,29 +86,36 @@ class CustomFailReasonHandler(RequestHandler):
|
|||
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(CurlHTTPClientTestCase, self).setUp()
|
||||
self.http_client = CurlAsyncHTTPClient(self.io_loop,
|
||||
defaults=dict(allow_ipv6=False))
|
||||
self.http_client = self.create_client()
|
||||
|
||||
def get_app(self):
|
||||
return Application([
|
||||
('/digest', DigestAuthHandler),
|
||||
('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}),
|
||||
('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}),
|
||||
('/custom_reason', CustomReasonHandler),
|
||||
('/custom_fail_reason', CustomFailReasonHandler),
|
||||
])
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return CurlAsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(allow_ipv6=False),
|
||||
**kwargs)
|
||||
|
||||
@gen_test
|
||||
def test_prepare_curl_callback_stack_context(self):
|
||||
exc_info = []
|
||||
error_event = Event()
|
||||
|
||||
def error_handler(typ, value, tb):
|
||||
exc_info.append((typ, value, tb))
|
||||
self.stop()
|
||||
error_event.set()
|
||||
return True
|
||||
|
||||
with ExceptionStackContext(error_handler):
|
||||
request = HTTPRequest(self.get_url('/'),
|
||||
prepare_curl_callback=lambda curl: 1 / 0)
|
||||
self.http_client.fetch(request, callback=self.stop)
|
||||
self.wait()
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
request = HTTPRequest(self.get_url('/custom_reason'),
|
||||
prepare_curl_callback=lambda curl: 1 / 0)
|
||||
yield [error_event.wait(), self.http_client.fetch(request)]
|
||||
self.assertEqual(1, len(exc_info))
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
|
|
@ -122,3 +132,22 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
|
|||
response = self.fetch('/custom_fail_reason')
|
||||
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
|
||||
|
||||
def test_failed_setup(self):
|
||||
self.http_client = self.create_client(max_clients=1)
|
||||
for i in range(5):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(u'/ユニコード')
|
||||
self.assertIsNot(response.error, None)
|
||||
|
||||
with self.assertRaises((UnicodeEncodeError, HTTPClientError)):
|
||||
# This raises UnicodeDecodeError on py3 and
|
||||
# HTTPClientError(404) on py2. The main motivation of
|
||||
# this test is to ensure that the UnicodeEncodeError
|
||||
# during the setup phase doesn't lead the request to
|
||||
# be dropped on the floor.
|
||||
response = self.fetch(u'/ユニコード', raise_error=True)
|
||||
|
||||
def test_digest_auth_non_ascii(self):
|
||||
response = self.fetch('/digest_non_ascii', auth_mode='digest',
|
||||
auth_username='foo', auth_password='barユ£')
|
||||
self.assertEqual(response.body, b'ok')
|
||||
|
|
|
|||
|
|
@ -1,134 +1,138 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
import tornado.escape
|
||||
|
||||
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode, squeeze, recursive_unicode
|
||||
from tornado.util import u, unicode_type
|
||||
from tornado.escape import (
|
||||
utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
|
||||
to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
|
||||
)
|
||||
from tornado.util import unicode_type
|
||||
from tornado.test.util import unittest
|
||||
|
||||
linkify_tests = [
|
||||
# (input, linkify_kwargs, expected_output)
|
||||
|
||||
("hello http://world.com/!", {},
|
||||
u('hello <a href="http://world.com/">http://world.com/</a>!')),
|
||||
u'hello <a href="http://world.com/">http://world.com/</a>!'),
|
||||
|
||||
("hello http://world.com/with?param=true&stuff=yes", {},
|
||||
u('hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>')),
|
||||
u'hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>'), # noqa: E501
|
||||
|
||||
# an opened paren followed by many chars killed Gruber's regex
|
||||
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
|
||||
u('<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')),
|
||||
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
|
||||
|
||||
# as did too many dots at the end
|
||||
("http://url.com/withmany.......................................", {},
|
||||
u('<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................')),
|
||||
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................'), # noqa: E501
|
||||
|
||||
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
|
||||
u('<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)')),
|
||||
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)'), # noqa: E501
|
||||
|
||||
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
|
||||
# plus a fex extras (such as multiple parentheses).
|
||||
("http://foo.com/blah_blah", {},
|
||||
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>')),
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>'),
|
||||
|
||||
("http://foo.com/blah_blah/", {},
|
||||
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>')),
|
||||
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>'),
|
||||
|
||||
("(Something like http://foo.com/blah_blah)", {},
|
||||
u('(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)')),
|
||||
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)'),
|
||||
|
||||
("http://foo.com/blah_blah_(wikipedia)", {},
|
||||
u('<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>')),
|
||||
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>'),
|
||||
|
||||
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
|
||||
u('<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>')),
|
||||
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>'), # noqa: E501
|
||||
|
||||
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
|
||||
u('(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)')),
|
||||
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)'), # noqa: E501
|
||||
|
||||
("http://foo.com/blah_blah.", {},
|
||||
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.')),
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.'),
|
||||
|
||||
("http://foo.com/blah_blah/.", {},
|
||||
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.')),
|
||||
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.'),
|
||||
|
||||
("<http://foo.com/blah_blah>", {},
|
||||
u('<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>')),
|
||||
u'<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>'),
|
||||
|
||||
("<http://foo.com/blah_blah/>", {},
|
||||
u('<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>')),
|
||||
u'<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>'),
|
||||
|
||||
("http://foo.com/blah_blah,", {},
|
||||
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,')),
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,'),
|
||||
|
||||
("http://www.example.com/wpstyle/?p=364.", {},
|
||||
u('<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.')),
|
||||
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.'),
|
||||
|
||||
("rdar://1234",
|
||||
{"permitted_protocols": ["http", "rdar"]},
|
||||
u('<a href="rdar://1234">rdar://1234</a>')),
|
||||
u'<a href="rdar://1234">rdar://1234</a>'),
|
||||
|
||||
("rdar:/1234",
|
||||
{"permitted_protocols": ["rdar"]},
|
||||
u('<a href="rdar:/1234">rdar:/1234</a>')),
|
||||
u'<a href="rdar:/1234">rdar:/1234</a>'),
|
||||
|
||||
("http://userid:password@example.com:8080", {},
|
||||
u('<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>')),
|
||||
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>'), # noqa: E501
|
||||
|
||||
("http://userid@example.com", {},
|
||||
u('<a href="http://userid@example.com">http://userid@example.com</a>')),
|
||||
u'<a href="http://userid@example.com">http://userid@example.com</a>'),
|
||||
|
||||
("http://userid@example.com:8080", {},
|
||||
u('<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>')),
|
||||
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>'),
|
||||
|
||||
("http://userid:password@example.com", {},
|
||||
u('<a href="http://userid:password@example.com">http://userid:password@example.com</a>')),
|
||||
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>'),
|
||||
|
||||
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
|
||||
{"permitted_protocols": ["http", "message"]},
|
||||
u('<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>')),
|
||||
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
|
||||
u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>'),
|
||||
|
||||
(u("http://\u27a1.ws/\u4a39"), {},
|
||||
u('<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>')),
|
||||
(u"http://\u27a1.ws/\u4a39", {},
|
||||
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>'),
|
||||
|
||||
("<tag>http://example.com</tag>", {},
|
||||
u('<tag><a href="http://example.com">http://example.com</a></tag>')),
|
||||
u'<tag><a href="http://example.com">http://example.com</a></tag>'),
|
||||
|
||||
("Just a www.example.com link.", {},
|
||||
u('Just a <a href="http://www.example.com">www.example.com</a> link.')),
|
||||
u'Just a <a href="http://www.example.com">www.example.com</a> link.'),
|
||||
|
||||
("Just a www.example.com link.",
|
||||
{"require_protocol": True},
|
||||
u('Just a www.example.com link.')),
|
||||
u'Just a www.example.com link.'),
|
||||
|
||||
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
|
||||
{"require_protocol": True, "shorten": True},
|
||||
u('A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html" title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>')),
|
||||
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
|
||||
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>'), # noqa: E501
|
||||
|
||||
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
|
||||
{"shorten": True},
|
||||
u('A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi" title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!')),
|
||||
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
|
||||
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!'), # noqa: E501
|
||||
|
||||
("A file:///passwords.txt and http://web.com link", {},
|
||||
u('A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link')),
|
||||
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link'),
|
||||
|
||||
("A file:///passwords.txt and http://web.com link",
|
||||
{"permitted_protocols": ["file"]},
|
||||
u('A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link')),
|
||||
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link'),
|
||||
|
||||
("www.external-link.com",
|
||||
{"extra_params": 'rel="nofollow" class="external"'},
|
||||
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
|
||||
|
||||
("www.external-link.com and www.internal-link.com/blogs extra",
|
||||
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'},
|
||||
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a> and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra')),
|
||||
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
|
||||
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra'), # noqa: E501
|
||||
|
||||
("www.external-link.com",
|
||||
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
|
||||
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -141,13 +145,13 @@ class EscapeTestCase(unittest.TestCase):
|
|||
def test_xhtml_escape(self):
|
||||
tests = [
|
||||
("<foo>", "<foo>"),
|
||||
(u("<foo>"), u("<foo>")),
|
||||
(u"<foo>", u"<foo>"),
|
||||
(b"<foo>", b"<foo>"),
|
||||
|
||||
("<>&\"'", "<>&"'"),
|
||||
("&", "&amp;"),
|
||||
|
||||
(u("<\u00e9>"), u("<\u00e9>")),
|
||||
(u"<\u00e9>", u"<\u00e9>"),
|
||||
(b"<\xc3\xa9>", b"<\xc3\xa9>"),
|
||||
]
|
||||
for unescaped, escaped in tests:
|
||||
|
|
@ -159,7 +163,7 @@ class EscapeTestCase(unittest.TestCase):
|
|||
('foo bar', 'foo bar'),
|
||||
('foo bar', 'foo bar'),
|
||||
('foo bar', 'foo bar'),
|
||||
('foo઼bar', u('foo\u0abcbar')),
|
||||
('foo઼bar', u'foo\u0abcbar'),
|
||||
('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
|
||||
('foo&#;bar', 'foo&#;bar'), # invalid encoding
|
||||
('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
|
||||
|
|
@ -170,20 +174,20 @@ class EscapeTestCase(unittest.TestCase):
|
|||
def test_url_escape_unicode(self):
|
||||
tests = [
|
||||
# byte strings are passed through as-is
|
||||
(u('\u00e9').encode('utf8'), '%C3%A9'),
|
||||
(u('\u00e9').encode('latin1'), '%E9'),
|
||||
(u'\u00e9'.encode('utf8'), '%C3%A9'),
|
||||
(u'\u00e9'.encode('latin1'), '%E9'),
|
||||
|
||||
# unicode strings become utf8
|
||||
(u('\u00e9'), '%C3%A9'),
|
||||
(u'\u00e9', '%C3%A9'),
|
||||
]
|
||||
for unescaped, escaped in tests:
|
||||
self.assertEqual(url_escape(unescaped), escaped)
|
||||
|
||||
def test_url_unescape_unicode(self):
|
||||
tests = [
|
||||
('%C3%A9', u('\u00e9'), 'utf8'),
|
||||
('%C3%A9', u('\u00c3\u00a9'), 'latin1'),
|
||||
('%C3%A9', utf8(u('\u00e9')), None),
|
||||
('%C3%A9', u'\u00e9', 'utf8'),
|
||||
('%C3%A9', u'\u00c3\u00a9', 'latin1'),
|
||||
('%C3%A9', utf8(u'\u00e9'), None),
|
||||
]
|
||||
for escaped, unescaped, encoding in tests:
|
||||
# input strings to url_unescape should only contain ascii
|
||||
|
|
@ -209,28 +213,29 @@ class EscapeTestCase(unittest.TestCase):
|
|||
# On python2 the escape methods should generally return the same
|
||||
# type as their argument
|
||||
self.assertEqual(type(xhtml_escape("foo")), str)
|
||||
self.assertEqual(type(xhtml_escape(u("foo"))), unicode_type)
|
||||
self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)
|
||||
|
||||
def test_json_decode(self):
|
||||
# json_decode accepts both bytes and unicode, but strings it returns
|
||||
# are always unicode.
|
||||
self.assertEqual(json_decode(b'"foo"'), u("foo"))
|
||||
self.assertEqual(json_decode(u('"foo"')), u("foo"))
|
||||
self.assertEqual(json_decode(b'"foo"'), u"foo")
|
||||
self.assertEqual(json_decode(u'"foo"'), u"foo")
|
||||
|
||||
# Non-ascii bytes are interpreted as utf8
|
||||
self.assertEqual(json_decode(utf8(u('"\u00e9"'))), u("\u00e9"))
|
||||
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
|
||||
|
||||
def test_json_encode(self):
|
||||
# json deals with strings, not bytes. On python 2 byte strings will
|
||||
# convert automatically if they are utf8; on python 3 byte strings
|
||||
# are not allowed.
|
||||
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
|
||||
self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
|
||||
if bytes is str:
|
||||
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
|
||||
self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
|
||||
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
|
||||
|
||||
def test_squeeze(self):
|
||||
self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
|
||||
self.assertEqual(squeeze(u'sequences of whitespace chars'),
|
||||
u'sequences of whitespace chars')
|
||||
|
||||
def test_recursive_unicode(self):
|
||||
tests = {
|
||||
|
|
@ -239,7 +244,7 @@ class EscapeTestCase(unittest.TestCase):
|
|||
'tuple': (b"foo", b"bar"),
|
||||
'bytes': b"foo"
|
||||
}
|
||||
self.assertEqual(recursive_unicode(tests['dict']), {u("foo"): u("bar")})
|
||||
self.assertEqual(recursive_unicode(tests['list']), [u("foo"), u("bar")])
|
||||
self.assertEqual(recursive_unicode(tests['tuple']), (u("foo"), u("bar")))
|
||||
self.assertEqual(recursive_unicode(tests['bytes']), u("foo"))
|
||||
self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
|
||||
self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
|
||||
self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
|
||||
self.assertEqual(recursive_unicode(tests['bytes']), u"foo")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import gc
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import platform
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import weakref
|
||||
import warnings
|
||||
|
||||
from tornado.concurrent import return_future, Future
|
||||
from tornado.escape import url_escape
|
||||
|
|
@ -15,7 +18,7 @@ from tornado.ioloop import IOLoop
|
|||
from tornado.log import app_log
|
||||
from tornado import stack_context
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test
|
||||
from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test, ignore_deprecation # noqa: E501
|
||||
from tornado.web import Application, RequestHandler, asynchronous, HTTPError
|
||||
|
||||
from tornado import gen
|
||||
|
|
@ -25,12 +28,24 @@ try:
|
|||
except ImportError:
|
||||
futures = None
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
class GenEngineTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
super(GenEngineTest, self).setUp()
|
||||
self.named_contexts = []
|
||||
|
||||
def tearDown(self):
|
||||
super(GenEngineTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
def named_context(self, name):
|
||||
@contextlib.contextmanager
|
||||
def context():
|
||||
|
|
@ -53,9 +68,10 @@ class GenEngineTest(AsyncTestCase):
|
|||
self.io_loop.add_callback(functools.partial(
|
||||
self.delay_callback, iterations - 1, callback, arg))
|
||||
|
||||
@return_future
|
||||
def async_future(self, result, callback):
|
||||
self.io_loop.add_callback(callback, result)
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def async_future(self, result, callback):
|
||||
self.io_loop.add_callback(callback, result)
|
||||
|
||||
@gen.coroutine
|
||||
def async_exception(self, e):
|
||||
|
|
@ -276,6 +292,13 @@ class GenEngineTest(AsyncTestCase):
|
|||
pass
|
||||
self.orphaned_callback()
|
||||
|
||||
def test_none(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
yield None
|
||||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_multi(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
|
|
@ -645,6 +668,237 @@ class GenEngineTest(AsyncTestCase):
|
|||
self.assertIs(self.task_ref(), None)
|
||||
|
||||
|
||||
# GenBasicTest duplicates the non-deprecated portions of GenEngineTest
|
||||
# with gen.coroutine to ensure we don't lose coverage when gen.engine
|
||||
# goes away.
|
||||
class GenBasicTest(AsyncTestCase):
|
||||
@gen.coroutine
|
||||
def delay(self, iterations, arg):
|
||||
"""Returns arg after a number of IOLoop iterations."""
|
||||
for i in range(iterations):
|
||||
yield gen.moment
|
||||
raise gen.Return(arg)
|
||||
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def async_future(self, result, callback):
|
||||
self.io_loop.add_callback(callback, result)
|
||||
|
||||
@gen.coroutine
|
||||
def async_exception(self, e):
|
||||
yield gen.moment
|
||||
raise e
|
||||
|
||||
@gen.coroutine
|
||||
def add_one_async(self, x):
|
||||
yield gen.moment
|
||||
raise gen.Return(x + 1)
|
||||
|
||||
def test_no_yield(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
pass
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_exception_phase1(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
1 / 0
|
||||
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
|
||||
|
||||
def test_exception_phase2(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
|
||||
|
||||
def test_bogus_yield(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield 42
|
||||
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
|
||||
|
||||
def test_bogus_yield_tuple(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield (1, 2)
|
||||
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
|
||||
|
||||
def test_reuse(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
self.io_loop.run_sync(f)
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_none(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield None
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_multi(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
results = yield [self.add_one_async(1), self.add_one_async(2)]
|
||||
self.assertEqual(results, [2, 3])
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_multi_dict(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2))
|
||||
self.assertEqual(results, dict(foo=2, bar=3))
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_multi_delayed(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield gen.multi_future([
|
||||
self.delay(3, "v1"),
|
||||
self.delay(1, "v2"),
|
||||
])
|
||||
self.assertEqual(responses, ["v1", "v2"])
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_multi_dict_delayed(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
# callbacks run at different times
|
||||
responses = yield gen.multi_future(dict(
|
||||
foo=self.delay(3, "v1"),
|
||||
bar=self.delay(1, "v2"),
|
||||
))
|
||||
self.assertEqual(responses, dict(foo="v1", bar="v2"))
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
@skipOnTravis
|
||||
@gen_test
|
||||
def test_multi_performance(self):
|
||||
# Yielding a list used to have quadratic performance; make
|
||||
# sure a large list stays reasonable. On my laptop a list of
|
||||
# 2000 used to take 1.8s, now it takes 0.12.
|
||||
start = time.time()
|
||||
yield [gen.moment for i in range(2000)]
|
||||
end = time.time()
|
||||
self.assertLess(end - start, 1.0)
|
||||
|
||||
@gen_test
|
||||
def test_multi_empty(self):
|
||||
# Empty lists or dicts should return the same type.
|
||||
x = yield []
|
||||
self.assertTrue(isinstance(x, list))
|
||||
y = yield {}
|
||||
self.assertTrue(isinstance(y, dict))
|
||||
|
||||
@gen_test
|
||||
def test_future(self):
|
||||
result = yield self.async_future(1)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
@gen_test
|
||||
def test_multi_future(self):
|
||||
results = yield [self.async_future(1), self.async_future(2)]
|
||||
self.assertEqual(results, [1, 2])
|
||||
|
||||
@gen_test
|
||||
def test_multi_future_duplicate(self):
|
||||
f = self.async_future(2)
|
||||
results = yield [self.async_future(1), f, self.async_future(3), f]
|
||||
self.assertEqual(results, [1, 2, 3, 2])
|
||||
|
||||
@gen_test
|
||||
def test_multi_dict_future(self):
|
||||
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
|
||||
self.assertEqual(results, dict(foo=1, bar=2))
|
||||
|
||||
@gen_test
|
||||
def test_multi_exceptions(self):
|
||||
with ExpectLog(app_log, "Multiple exceptions in yield list"):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))])
|
||||
self.assertEqual(str(cm.exception), "error 1")
|
||||
|
||||
# With only one exception, no error is logged.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_future(2)])
|
||||
|
||||
# Exception logging may be explicitly quieted.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))],
|
||||
quiet_exceptions=RuntimeError)
|
||||
|
||||
@gen_test
|
||||
def test_multi_future_exceptions(self):
|
||||
with ExpectLog(app_log, "Multiple exceptions in yield list"):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
yield [self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))]
|
||||
self.assertEqual(str(cm.exception), "error 1")
|
||||
|
||||
# With only one exception, no error is logged.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield [self.async_exception(RuntimeError("error 1")),
|
||||
self.async_future(2)]
|
||||
|
||||
# Exception logging may be explicitly quieted.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.multi_future(
|
||||
[self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))],
|
||||
quiet_exceptions=RuntimeError)
|
||||
|
||||
def test_sync_raise_return(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
raise gen.Return()
|
||||
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_async_raise_return(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
raise gen.Return()
|
||||
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_sync_raise_return_value(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
raise gen.Return(42)
|
||||
|
||||
self.assertEqual(42, self.io_loop.run_sync(f))
|
||||
|
||||
def test_sync_raise_return_value_tuple(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
raise gen.Return((1, 2))
|
||||
|
||||
self.assertEqual((1, 2), self.io_loop.run_sync(f))
|
||||
|
||||
def test_async_raise_return_value(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
|
||||
self.assertEqual(42, self.io_loop.run_sync(f))
|
||||
|
||||
def test_async_raise_return_value_tuple(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
raise gen.Return((1, 2))
|
||||
|
||||
self.assertEqual((1, 2), self.io_loop.run_sync(f))
|
||||
|
||||
|
||||
class GenCoroutineTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
# Stray StopIteration exceptions can lead to tests exiting prematurely,
|
||||
|
|
@ -657,6 +911,28 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
super(GenCoroutineTest, self).tearDown()
|
||||
assert self.finished
|
||||
|
||||
def test_attributes(self):
|
||||
self.finished = True
|
||||
|
||||
def f():
|
||||
yield gen.moment
|
||||
|
||||
coro = gen.coroutine(f)
|
||||
self.assertEqual(coro.__name__, f.__name__)
|
||||
self.assertEqual(coro.__module__, f.__module__)
|
||||
self.assertIs(coro.__wrapped__, f)
|
||||
|
||||
def test_is_coroutine_function(self):
|
||||
self.finished = True
|
||||
|
||||
def f():
|
||||
yield gen.moment
|
||||
|
||||
coro = gen.coroutine(f)
|
||||
self.assertFalse(gen.is_coroutine_function(f))
|
||||
self.assertTrue(gen.is_coroutine_function(coro))
|
||||
self.assertFalse(gen.is_coroutine_function(coro()))
|
||||
|
||||
@gen_test
|
||||
def test_sync_gen_return(self):
|
||||
@gen.coroutine
|
||||
|
|
@ -670,7 +946,7 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
def test_async_gen_return(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
result = yield f()
|
||||
self.assertEqual(result, 42)
|
||||
|
|
@ -691,7 +967,7 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
namespace = exec_test(globals(), locals(), """
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
return 42
|
||||
""")
|
||||
result = yield namespace['f']()
|
||||
|
|
@ -718,12 +994,32 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
@skipBefore35
|
||||
@gen_test
|
||||
def test_async_await(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
|
||||
# This test verifies that an async function can await a
|
||||
# yield-based gen.coroutine, and that a gen.coroutine
|
||||
# (the test method itself) can yield an async function.
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f2():
|
||||
result = await f1()
|
||||
return result
|
||||
""")
|
||||
result = yield namespace['f2']()
|
||||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_asyncio_sleep_zero(self):
|
||||
# asyncio.sleep(0) turns into a special case (equivalent to
|
||||
# `yield None`)
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
await gen.Task(self.io_loop.add_callback)
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
return 42
|
||||
""")
|
||||
result = yield namespace['f']()
|
||||
|
|
@ -733,18 +1029,22 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
@skipBefore35
|
||||
@gen_test
|
||||
def test_async_await_mixed_multi_native_future(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
yield gen.moment
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f1():
|
||||
await gen.Task(self.io_loop.add_callback)
|
||||
async def f2():
|
||||
await f1()
|
||||
return 42
|
||||
""")
|
||||
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
def f3():
|
||||
yield gen.moment
|
||||
raise gen.Return(43)
|
||||
|
||||
results = yield [namespace['f1'](), f2()]
|
||||
results = yield [namespace['f2'](), f3()]
|
||||
self.assertEqual(results, [42, 43])
|
||||
self.finished = True
|
||||
|
||||
|
|
@ -762,11 +1062,25 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
yield gen.Task(self.io_loop.add_callback)
|
||||
raise gen.Return(43)
|
||||
|
||||
f2(callback=(yield gen.Callback('cb')))
|
||||
results = yield [namespace['f1'](), gen.Wait('cb')]
|
||||
with ignore_deprecation():
|
||||
f2(callback=(yield gen.Callback('cb')))
|
||||
results = yield [namespace['f1'](), gen.Wait('cb')]
|
||||
self.assertEqual(results, [42, 43])
|
||||
self.finished = True
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_async_with_timeout(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f1():
|
||||
return 42
|
||||
""")
|
||||
|
||||
result = yield gen.with_timeout(datetime.timedelta(hours=1),
|
||||
namespace['f1']())
|
||||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
|
||||
@gen_test
|
||||
def test_sync_return_no_value(self):
|
||||
@gen.coroutine
|
||||
|
|
@ -781,7 +1095,7 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
# Without a return value we don't need python 3.3.
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
return
|
||||
result = yield f()
|
||||
self.assertEqual(result, None)
|
||||
|
|
@ -804,7 +1118,7 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
def test_async_raise(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
future = f()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
|
|
@ -813,10 +1127,11 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
|
||||
@gen_test
|
||||
def test_pass_callback(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
raise gen.Return(42)
|
||||
result = yield gen.Task(f)
|
||||
with ignore_deprecation():
|
||||
@gen.coroutine
|
||||
def f():
|
||||
raise gen.Return(42)
|
||||
result = yield gen.Task(f)
|
||||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
|
||||
|
|
@ -861,46 +1176,48 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
|
||||
@gen_test
|
||||
def test_replace_context_exception(self):
|
||||
# Test exception handling: exceptions thrown into the stack context
|
||||
# can be caught and replaced.
|
||||
# Note that this test and the following are for behavior that is
|
||||
# not really supported any more: coroutines no longer create a
|
||||
# stack context automatically; but one is created after the first
|
||||
# YieldPoint (i.e. not a Future).
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
self.io_loop.time() + 10)
|
||||
except ZeroDivisionError:
|
||||
raise KeyError()
|
||||
with ignore_deprecation():
|
||||
# Test exception handling: exceptions thrown into the stack context
|
||||
# can be caught and replaced.
|
||||
# Note that this test and the following are for behavior that is
|
||||
# not really supported any more: coroutines no longer create a
|
||||
# stack context automatically; but one is created after the first
|
||||
# YieldPoint (i.e. not a Future).
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
self.io_loop.time() + 10)
|
||||
except ZeroDivisionError:
|
||||
raise KeyError()
|
||||
|
||||
future = f2()
|
||||
with self.assertRaises(KeyError):
|
||||
yield future
|
||||
self.finished = True
|
||||
future = f2()
|
||||
with self.assertRaises(KeyError):
|
||||
yield future
|
||||
self.finished = True
|
||||
|
||||
@gen_test
|
||||
def test_swallow_context_exception(self):
|
||||
# Test exception handling: exceptions thrown into the stack context
|
||||
# can be caught and ignored.
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
self.io_loop.time() + 10)
|
||||
except ZeroDivisionError:
|
||||
raise gen.Return(42)
|
||||
with ignore_deprecation():
|
||||
# Test exception handling: exceptions thrown into the stack context
|
||||
# can be caught and ignored.
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
(yield gen.Callback(1))()
|
||||
yield gen.Wait(1)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(self.io_loop.add_timeout,
|
||||
self.io_loop.time() + 10)
|
||||
except ZeroDivisionError:
|
||||
raise gen.Return(42)
|
||||
|
||||
result = yield f2()
|
||||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
result = yield f2()
|
||||
self.assertEqual(result, 42)
|
||||
self.finished = True
|
||||
|
||||
@gen_test
|
||||
def test_moment(self):
|
||||
|
|
@ -957,77 +1274,132 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
|
||||
self.finished = True
|
||||
|
||||
@skipNotCPython
|
||||
@unittest.skipIf((3,) < sys.version_info < (3, 6),
|
||||
"asyncio.Future has reference cycles")
|
||||
def test_coroutine_refcounting(self):
|
||||
# On CPython, tasks and their arguments should be released immediately
|
||||
# without waiting for garbage collection.
|
||||
@gen.coroutine
|
||||
def inner():
|
||||
class Foo(object):
|
||||
pass
|
||||
local_var = Foo()
|
||||
self.local_ref = weakref.ref(local_var)
|
||||
yield gen.coroutine(lambda: None)()
|
||||
raise ValueError('Some error')
|
||||
|
||||
@gen.coroutine
|
||||
def inner2():
|
||||
try:
|
||||
yield inner()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self.io_loop.run_sync(inner2, timeout=3)
|
||||
|
||||
self.assertIs(self.local_ref(), None)
|
||||
self.finished = True
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3,),
|
||||
"test only relevant with asyncio Futures")
|
||||
def test_asyncio_future_debug_info(self):
|
||||
self.finished = True
|
||||
# Enable debug mode
|
||||
asyncio_loop = asyncio.get_event_loop()
|
||||
self.addCleanup(asyncio_loop.set_debug, asyncio_loop.get_debug())
|
||||
asyncio_loop.set_debug(True)
|
||||
|
||||
def f():
|
||||
yield gen.moment
|
||||
|
||||
coro = gen.coroutine(f)()
|
||||
self.assertIsInstance(coro, asyncio.Future)
|
||||
# We expect the coroutine repr() to show the place where
|
||||
# it was instantiated
|
||||
expected = ("created at %s:%d"
|
||||
% (__file__, f.__code__.co_firstlineno + 3))
|
||||
actual = repr(coro)
|
||||
self.assertIn(expected, actual)
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
@gen_test
|
||||
def test_asyncio_gather(self):
|
||||
# This demonstrates that tornado coroutines can be understood
|
||||
# by asyncio (This failed prior to Tornado 5.0).
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
raise gen.Return(1)
|
||||
|
||||
ret = yield asyncio.gather(f(), f())
|
||||
self.assertEqual(ret, [1, 1])
|
||||
self.finished = True
|
||||
|
||||
|
||||
class GenSequenceHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
self.io_loop = self.request.connection.stream.io_loop
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
self.write("1")
|
||||
self.io_loop.add_callback((yield gen.Callback("k2")))
|
||||
yield gen.Wait("k2")
|
||||
self.write("2")
|
||||
# reuse an old key
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
self.finish("3")
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
# The outer ignore_deprecation applies at definition time.
|
||||
# We need another for serving time.
|
||||
with ignore_deprecation():
|
||||
self.io_loop = self.request.connection.stream.io_loop
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
self.write("1")
|
||||
self.io_loop.add_callback((yield gen.Callback("k2")))
|
||||
yield gen.Wait("k2")
|
||||
self.write("2")
|
||||
# reuse an old key
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
self.finish("3")
|
||||
|
||||
|
||||
class GenCoroutineSequenceHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.io_loop = self.request.connection.stream.io_loop
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
yield gen.moment
|
||||
self.write("1")
|
||||
self.io_loop.add_callback((yield gen.Callback("k2")))
|
||||
yield gen.Wait("k2")
|
||||
yield gen.moment
|
||||
self.write("2")
|
||||
# reuse an old key
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
yield gen.moment
|
||||
self.finish("3")
|
||||
|
||||
|
||||
class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.io_loop = self.request.connection.stream.io_loop
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
yield gen.moment
|
||||
self.write("1")
|
||||
self.io_loop.add_callback((yield gen.Callback("k2")))
|
||||
yield gen.Wait("k2")
|
||||
yield gen.moment
|
||||
self.write("2")
|
||||
# reuse an old key
|
||||
self.io_loop.add_callback((yield gen.Callback("k1")))
|
||||
yield gen.Wait("k1")
|
||||
yield gen.moment
|
||||
# just write, don't finish
|
||||
self.write("3")
|
||||
|
||||
|
||||
class GenTaskHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
io_loop = self.request.connection.stream.io_loop
|
||||
client = AsyncHTTPClient(io_loop=io_loop)
|
||||
response = yield gen.Task(client.fetch, self.get_argument('url'))
|
||||
client = AsyncHTTPClient()
|
||||
with ignore_deprecation():
|
||||
response = yield gen.Task(client.fetch, self.get_argument('url'))
|
||||
response.rethrow()
|
||||
self.finish(b"got response: " + response.body)
|
||||
|
||||
|
||||
class GenExceptionHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
# This test depends on the order of the two decorators.
|
||||
io_loop = self.request.connection.stream.io_loop
|
||||
yield gen.Task(io_loop.add_callback)
|
||||
raise Exception("oops")
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
# This test depends on the order of the two decorators.
|
||||
io_loop = self.request.connection.stream.io_loop
|
||||
yield gen.Task(io_loop.add_callback)
|
||||
raise Exception("oops")
|
||||
|
||||
|
||||
class GenCoroutineExceptionHandler(RequestHandler):
|
||||
|
|
@ -1040,19 +1412,18 @@ class GenCoroutineExceptionHandler(RequestHandler):
|
|||
|
||||
|
||||
class GenYieldExceptionHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
io_loop = self.request.connection.stream.io_loop
|
||||
# Test the interaction of the two stack_contexts.
|
||||
|
||||
def fail_task(callback):
|
||||
io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(fail_task)
|
||||
raise Exception("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.finish('ok')
|
||||
with ignore_deprecation():
|
||||
def fail_task(callback):
|
||||
io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
yield gen.Task(fail_task)
|
||||
raise Exception("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.finish('ok')
|
||||
|
||||
|
||||
# "Undecorated" here refers to the absence of @asynchronous.
|
||||
|
|
@ -1060,22 +1431,22 @@ class UndecoratedCoroutinesHandler(RequestHandler):
|
|||
@gen.coroutine
|
||||
def prepare(self):
|
||||
self.chunks = []
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
self.chunks.append('1')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.chunks.append('2')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
self.chunks.append('3')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
self.write(''.join(self.chunks))
|
||||
|
||||
|
||||
class AsyncPrepareErrorHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def prepare(self):
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
raise HTTPError(403)
|
||||
|
||||
def get(self):
|
||||
|
|
@ -1086,7 +1457,8 @@ class NativeCoroutineHandler(RequestHandler):
|
|||
if sys.version_info > (3, 5):
|
||||
exec(textwrap.dedent("""
|
||||
async def get(self):
|
||||
await gen.Task(IOLoop.current().add_callback)
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
self.write("ok")
|
||||
"""))
|
||||
|
||||
|
|
@ -1167,7 +1539,7 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
|
||||
lambda: future.set_result('asdf'))
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future, io_loop=self.io_loop)
|
||||
future)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@gen_test
|
||||
|
|
@ -1178,19 +1550,20 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
lambda: future.set_exception(ZeroDivisionError()))
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future, io_loop=self.io_loop)
|
||||
future)
|
||||
|
||||
@gen_test
|
||||
def test_already_resolved(self):
|
||||
future = Future()
|
||||
future.set_result('asdf')
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future, io_loop=self.io_loop)
|
||||
future)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
@gen_test
|
||||
def test_timeout_concurrent_future(self):
|
||||
# A concurrent future that does not resolve before the timeout.
|
||||
with futures.ThreadPoolExecutor(1) as executor:
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield gen.with_timeout(self.io_loop.time(),
|
||||
|
|
@ -1199,9 +1572,20 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
@gen_test
|
||||
def test_completed_concurrent_future(self):
|
||||
# A concurrent future that is resolved before we even submit it
|
||||
# to with_timeout.
|
||||
with futures.ThreadPoolExecutor(1) as executor:
|
||||
f = executor.submit(lambda: None)
|
||||
f.result() # wait for completion
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600), f)
|
||||
|
||||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
@gen_test
|
||||
def test_normal_concurrent_future(self):
|
||||
# A conccurrent future that resolves while waiting for the timeout.
|
||||
with futures.ThreadPoolExecutor(1) as executor:
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
executor.submit(lambda: None))
|
||||
executor.submit(lambda: time.sleep(0.01)))
|
||||
|
||||
|
||||
class WaitIteratorTest(AsyncTestCase):
|
||||
|
|
@ -1355,5 +1739,124 @@ class WaitIteratorTest(AsyncTestCase):
|
|||
gen.WaitIterator(gen.sleep(0)).next())
|
||||
|
||||
|
||||
class RunnerGCTest(AsyncTestCase):
|
||||
def is_pypy3(self):
|
||||
return (platform.python_implementation() == 'PyPy' and
|
||||
sys.version_info > (3,))
|
||||
|
||||
@gen_test
|
||||
def test_gc(self):
|
||||
# Github issue 1769: Runner objects can get GCed unexpectedly
|
||||
# while their future is alive.
|
||||
weakref_scope = [None]
|
||||
|
||||
def callback():
|
||||
gc.collect(2)
|
||||
weakref_scope[0]().set_result(123)
|
||||
|
||||
@gen.coroutine
|
||||
def tester():
|
||||
fut = Future()
|
||||
weakref_scope[0] = weakref.ref(fut)
|
||||
self.io_loop.add_callback(callback)
|
||||
yield fut
|
||||
|
||||
yield gen.with_timeout(
|
||||
datetime.timedelta(seconds=0.2),
|
||||
tester()
|
||||
)
|
||||
|
||||
def test_gc_infinite_coro(self):
|
||||
# Github issue 2229: suspended coroutines should be GCed when
|
||||
# their loop is closed, even if they're involved in a reference
|
||||
# cycle.
|
||||
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
|
||||
raise unittest.SkipTest("Test may fail on TwistedIOLoop")
|
||||
|
||||
loop = self.get_new_ioloop()
|
||||
result = []
|
||||
wfut = []
|
||||
|
||||
@gen.coroutine
|
||||
def infinite_coro():
|
||||
try:
|
||||
while True:
|
||||
yield gen.sleep(1e-3)
|
||||
result.append(True)
|
||||
finally:
|
||||
# coroutine finalizer
|
||||
result.append(None)
|
||||
|
||||
@gen.coroutine
|
||||
def do_something():
|
||||
fut = infinite_coro()
|
||||
fut._refcycle = fut
|
||||
wfut.append(weakref.ref(fut))
|
||||
yield gen.sleep(0.2)
|
||||
|
||||
loop.run_sync(do_something)
|
||||
loop.close()
|
||||
gc.collect()
|
||||
# Future was collected
|
||||
self.assertIs(wfut[0](), None)
|
||||
# At least one wakeup
|
||||
self.assertGreaterEqual(len(result), 2)
|
||||
if not self.is_pypy3():
|
||||
# coroutine finalizer was called (not on PyPy3 apparently)
|
||||
self.assertIs(result[-1], None)
|
||||
|
||||
@skipBefore35
|
||||
def test_gc_infinite_async_await(self):
|
||||
# Same as test_gc_infinite_coro, but with a `async def` function
|
||||
import asyncio
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def infinite_coro(result):
|
||||
try:
|
||||
while True:
|
||||
await gen.sleep(1e-3)
|
||||
result.append(True)
|
||||
finally:
|
||||
# coroutine finalizer
|
||||
result.append(None)
|
||||
""")
|
||||
|
||||
infinite_coro = namespace['infinite_coro']
|
||||
loop = self.get_new_ioloop()
|
||||
result = []
|
||||
wfut = []
|
||||
|
||||
@gen.coroutine
|
||||
def do_something():
|
||||
fut = asyncio.get_event_loop().create_task(infinite_coro(result))
|
||||
fut._refcycle = fut
|
||||
wfut.append(weakref.ref(fut))
|
||||
yield gen.sleep(0.2)
|
||||
|
||||
loop.run_sync(do_something)
|
||||
with ExpectLog('asyncio', "Task was destroyed but it is pending"):
|
||||
loop.close()
|
||||
gc.collect()
|
||||
# Future was collected
|
||||
self.assertIs(wfut[0](), None)
|
||||
# At least one wakeup and one finally
|
||||
self.assertGreaterEqual(len(result), 2)
|
||||
if not self.is_pypy3():
|
||||
# coroutine finalizer was called (not on PyPy3 apparently)
|
||||
self.assertIs(result[-1], None)
|
||||
|
||||
def test_multi_moment(self):
|
||||
# Test gen.multi with moment
|
||||
# now that it's not a real Future
|
||||
@gen.coroutine
|
||||
def wait_a_moment():
|
||||
result = yield gen.multi([gen.moment, gen.moment])
|
||||
raise gen.Return(result)
|
||||
|
||||
loop = self.get_new_ioloop()
|
||||
result = loop.run_sync(wait_a_moment)
|
||||
self.assertEqual(result, [None, None])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import socket
|
||||
|
||||
from tornado.http1connection import HTTP1Connection
|
||||
from tornado.httputil import HTTPMessageDelegate
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.locks import Event
|
||||
from tornado.netutil import add_accept_handler
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
|
||||
|
||||
|
||||
class HTTP1ConnectionTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(HTTP1ConnectionTest, self).setUp()
|
||||
self.asyncSetUp()
|
||||
|
||||
@gen_test
|
||||
def asyncSetUp(self):
|
||||
listener, port = bind_unused_port()
|
||||
event = Event()
|
||||
|
||||
def accept_callback(conn, addr):
|
||||
self.server_stream = IOStream(conn)
|
||||
self.addCleanup(self.server_stream.close)
|
||||
event.set()
|
||||
|
||||
add_accept_handler(listener, accept_callback)
|
||||
self.client_stream = IOStream(socket.socket())
|
||||
self.addCleanup(self.client_stream.close)
|
||||
yield [self.client_stream.connect(('127.0.0.1', port)),
|
||||
event.wait()]
|
||||
self.io_loop.remove_handler(listener)
|
||||
listener.close()
|
||||
|
||||
@gen_test
|
||||
def test_http10_no_content_length(self):
|
||||
# Regression test for a bug in which can_keep_alive would crash
|
||||
# for an HTTP/1.0 (not 1.1) response with no content-length.
|
||||
conn = HTTP1Connection(self.client_stream, True)
|
||||
self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
|
||||
self.server_stream.close()
|
||||
|
||||
event = Event()
|
||||
test = self
|
||||
body = []
|
||||
|
||||
class Delegate(HTTPMessageDelegate):
|
||||
def headers_received(self, start_line, headers):
|
||||
test.code = start_line.code
|
||||
|
||||
def data_received(self, data):
|
||||
body.append(data)
|
||||
|
||||
def finish(self):
|
||||
event.set()
|
||||
|
||||
yield conn.read_response(Delegate())
|
||||
yield event.wait()
|
||||
self.assertEqual(self.code, 200)
|
||||
self.assertEqual(b''.join(body), b'hello')
|
||||
|
|
@ -1,18 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from contextlib import closing
|
||||
import copy
|
||||
import functools
|
||||
import sys
|
||||
import threading
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
|
|
@ -22,8 +22,7 @@ from tornado.log import gen_log
|
|||
from tornado import netutil
|
||||
from tornado.stack_context import ExceptionStackContext, NullContext
|
||||
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
|
||||
from tornado.test.util import unittest, skipOnTravis
|
||||
from tornado.util import u
|
||||
from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
|
||||
from tornado.web import Application, RequestHandler, url
|
||||
from tornado.httputil import format_timestamp, HTTPHeaders
|
||||
|
||||
|
|
@ -114,6 +113,15 @@ class AllMethodsHandler(RequestHandler):
|
|||
|
||||
get = post = put = delete = options = patch = other = method
|
||||
|
||||
|
||||
class SetHeaderHandler(RequestHandler):
|
||||
def get(self):
|
||||
# Use get_arguments for keys to get strings, but
|
||||
# request.arguments for values to get bytes.
|
||||
for k, v in zip(self.get_arguments('k'),
|
||||
self.request.arguments['v']):
|
||||
self.set_header(k, v)
|
||||
|
||||
# These tests end up getting run redundantly: once here with the default
|
||||
# HTTPClient implementation, and then again in each implementation's own
|
||||
# test suite.
|
||||
|
|
@ -134,6 +142,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
|
|||
url("/304_with_content_length", ContentLength304Handler),
|
||||
url("/all_methods", AllMethodsHandler),
|
||||
url('/patch', PatchHandler),
|
||||
url('/set_header', SetHeaderHandler),
|
||||
], gzip=True)
|
||||
|
||||
def test_patch_receives_payload(self):
|
||||
|
|
@ -183,10 +192,15 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
|
|||
# over several ioloop iterations, but the connection is already closed.
|
||||
sock, port = bind_unused_port()
|
||||
with closing(sock):
|
||||
def write_response(stream, request_data):
|
||||
@gen.coroutine
|
||||
def accept_callback(conn, address):
|
||||
# fake an HTTP server using chunked encoding where the final chunks
|
||||
# and connection close all happen at once
|
||||
stream = IOStream(conn)
|
||||
request_data = yield stream.read_until(b"\r\n\r\n")
|
||||
if b"HTTP/1." not in request_data:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
stream.write(b"""\
|
||||
yield stream.write(b"""\
|
||||
HTTP/1.1 200 OK
|
||||
Transfer-Encoding: chunked
|
||||
|
||||
|
|
@ -196,17 +210,10 @@ Transfer-Encoding: chunked
|
|||
2
|
||||
0
|
||||
|
||||
""".replace(b"\n", b"\r\n"), callback=stream.close)
|
||||
|
||||
def accept_callback(conn, address):
|
||||
# fake an HTTP server using chunked encoding where the final chunks
|
||||
# and connection close all happen at once
|
||||
stream = IOStream(conn, io_loop=self.io_loop)
|
||||
stream.read_until(b"\r\n\r\n",
|
||||
functools.partial(write_response, stream))
|
||||
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
|
||||
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
|
||||
resp = self.wait()
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
stream.close()
|
||||
netutil.add_accept_handler(sock, accept_callback)
|
||||
resp = self.fetch("http://127.0.0.1:%d/" % port)
|
||||
resp.rethrow()
|
||||
self.assertEqual(resp.body, b"12")
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
|
|
@ -224,14 +231,16 @@ Transfer-Encoding: chunked
|
|||
if chunk == b'qwer':
|
||||
1 / 0
|
||||
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', streaming_callback=streaming_cb)
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', streaming_callback=streaming_cb)
|
||||
|
||||
self.assertEqual(chunks, [b'asdf', b'qwer'])
|
||||
self.assertEqual(1, len(exc_info))
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
def test_basic_auth(self):
|
||||
# This test data appears in section 2 of RFC 7617.
|
||||
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame").body,
|
||||
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
|
||||
|
|
@ -242,16 +251,30 @@ Transfer-Encoding: chunked
|
|||
auth_mode="basic").body,
|
||||
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
|
||||
|
||||
def test_basic_auth_unicode(self):
|
||||
# This test data appears in section 2.1 of RFC 7617.
|
||||
self.assertEqual(self.fetch("/auth", auth_username="test",
|
||||
auth_password="123£").body,
|
||||
b"Basic dGVzdDoxMjPCow==")
|
||||
|
||||
# The standard mandates NFC. Give it a decomposed username
|
||||
# and ensure it is normalized to composed form.
|
||||
username = unicodedata.normalize("NFD", u"josé")
|
||||
self.assertEqual(self.fetch("/auth",
|
||||
auth_username=username,
|
||||
auth_password="səcrət").body,
|
||||
b"Basic am9zw6k6c8mZY3LJmXQ=")
|
||||
|
||||
def test_unsupported_auth_mode(self):
|
||||
# curl and simple clients handle errors a bit differently; the
|
||||
# important thing is that they don't fall back to basic auth
|
||||
# on an unknown mode.
|
||||
with ExpectLog(gen_log, "uncaught exception", required=False):
|
||||
with self.assertRaises((ValueError, HTTPError)):
|
||||
response = self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame",
|
||||
auth_mode="asdf")
|
||||
response.rethrow()
|
||||
self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame",
|
||||
auth_mode="asdf",
|
||||
raise_error=True)
|
||||
|
||||
def test_follow_redirect(self):
|
||||
response = self.fetch("/countdown/2", follow_redirects=False)
|
||||
|
|
@ -265,13 +288,12 @@ Transfer-Encoding: chunked
|
|||
|
||||
def test_credentials_in_url(self):
|
||||
url = self.get_url("/auth").replace("http://", "http://me:secret@")
|
||||
self.http_client.fetch(url, self.stop)
|
||||
response = self.wait()
|
||||
response = self.fetch(url)
|
||||
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
|
||||
response.body)
|
||||
|
||||
def test_body_encoding(self):
|
||||
unicode_body = u("\xe9")
|
||||
unicode_body = u"\xe9"
|
||||
byte_body = binascii.a2b_hex(b"e9")
|
||||
|
||||
# unicode string in body gets converted to utf8
|
||||
|
|
@ -291,7 +313,7 @@ Transfer-Encoding: chunked
|
|||
# break anything
|
||||
response = self.fetch("/echopost", method="POST", body=byte_body,
|
||||
headers={"Content-Type": "application/blah"},
|
||||
user_agent=u("foo"))
|
||||
user_agent=u"foo")
|
||||
self.assertEqual(response.headers["Content-Length"], "1")
|
||||
self.assertEqual(response.body, byte_body)
|
||||
|
||||
|
|
@ -341,19 +363,20 @@ Transfer-Encoding: chunked
|
|||
if header_line.lower().startswith('content-type:'):
|
||||
1 / 0
|
||||
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', header_callback=header_callback)
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', header_callback=header_callback)
|
||||
self.assertEqual(len(exc_info), 1)
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
@gen_test
|
||||
def test_configure_defaults(self):
|
||||
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
|
||||
# Construct a new instance of the configured client class
|
||||
client = self.http_client.__class__(self.io_loop, force_instance=True,
|
||||
client = self.http_client.__class__(force_instance=True,
|
||||
defaults=defaults)
|
||||
try:
|
||||
client.fetch(self.get_url('/user_agent'), callback=self.stop)
|
||||
response = self.wait()
|
||||
response = yield client.fetch(self.get_url('/user_agent'))
|
||||
self.assertEqual(response.body, b'TestDefaultUserAgent')
|
||||
finally:
|
||||
client.close()
|
||||
|
|
@ -363,7 +386,7 @@ Transfer-Encoding: chunked
|
|||
# in a plain dictionary or an HTTPHeaders object.
|
||||
# Keys must always be the native str type.
|
||||
# All combinations should have the same results on the wire.
|
||||
for value in [u("MyUserAgent"), b"MyUserAgent"]:
|
||||
for value in [u"MyUserAgent", b"MyUserAgent"]:
|
||||
for container in [dict, HTTPHeaders]:
|
||||
headers = container()
|
||||
headers['User-Agent'] = value
|
||||
|
|
@ -378,23 +401,22 @@ Transfer-Encoding: chunked
|
|||
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
||||
sock, port = bind_unused_port()
|
||||
with closing(sock):
|
||||
def write_response(stream, request_data):
|
||||
@gen.coroutine
|
||||
def accept_callback(conn, address):
|
||||
stream = IOStream(conn)
|
||||
request_data = yield stream.read_until(b"\r\n\r\n")
|
||||
if b"HTTP/1." not in request_data:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
stream.write(b"""\
|
||||
yield stream.write(b"""\
|
||||
HTTP/1.1 200 OK
|
||||
X-XSS-Protection: 1;
|
||||
\tmode=block
|
||||
|
||||
""".replace(b"\n", b"\r\n"), callback=stream.close)
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
stream.close()
|
||||
|
||||
def accept_callback(conn, address):
|
||||
stream = IOStream(conn, io_loop=self.io_loop)
|
||||
stream.read_until(b"\r\n\r\n",
|
||||
functools.partial(write_response, stream))
|
||||
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
|
||||
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
|
||||
resp = self.wait()
|
||||
netutil.add_accept_handler(sock, accept_callback)
|
||||
resp = self.fetch("http://127.0.0.1:%d/" % port)
|
||||
resp.rethrow()
|
||||
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
|
|
@ -424,8 +446,9 @@ X-XSS-Protection: 1;
|
|||
self.stop()
|
||||
self.io_loop.handle_callback_exception = handle_callback_exception
|
||||
with NullContext():
|
||||
self.http_client.fetch(self.get_url('/hello'),
|
||||
lambda response: 1 / 0)
|
||||
with ignore_deprecation():
|
||||
self.http_client.fetch(self.get_url('/hello'),
|
||||
lambda response: 1 / 0)
|
||||
self.wait()
|
||||
self.assertEqual(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
|
|
@ -476,8 +499,7 @@ X-XSS-Protection: 1;
|
|||
# These methods require a body.
|
||||
for method in ('POST', 'PUT', 'PATCH'):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
resp = self.fetch('/all_methods', method=method)
|
||||
resp.rethrow()
|
||||
self.fetch('/all_methods', method=method, raise_error=True)
|
||||
self.assertIn('must not be None', str(context.exception))
|
||||
|
||||
resp = self.fetch('/all_methods', method=method,
|
||||
|
|
@ -487,16 +509,14 @@ X-XSS-Protection: 1;
|
|||
# These methods don't allow a body.
|
||||
for method in ('GET', 'DELETE', 'OPTIONS'):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
resp = self.fetch('/all_methods', method=method, body=b'asdf')
|
||||
resp.rethrow()
|
||||
self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
|
||||
self.assertIn('must be None', str(context.exception))
|
||||
|
||||
# In most cases this can be overridden, but curl_httpclient
|
||||
# does not allow body with a GET at all.
|
||||
if method != 'GET':
|
||||
resp = self.fetch('/all_methods', method=method, body=b'asdf',
|
||||
allow_nonstandard_methods=True)
|
||||
resp.rethrow()
|
||||
self.fetch('/all_methods', method=method, body=b'asdf',
|
||||
allow_nonstandard_methods=True, raise_error=True)
|
||||
self.assertEqual(resp.code, 200)
|
||||
|
||||
# This test causes odd failures with the combination of
|
||||
|
|
@ -521,6 +541,28 @@ X-XSS-Protection: 1;
|
|||
response.rethrow()
|
||||
self.assertEqual(response.body, b"Put body: hello")
|
||||
|
||||
def test_non_ascii_header(self):
|
||||
# Non-ascii headers are sent as latin1.
|
||||
response = self.fetch("/set_header?k=foo&v=%E9")
|
||||
response.rethrow()
|
||||
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
|
||||
|
||||
def test_response_times(self):
|
||||
# A few simple sanity checks of the response time fields to
|
||||
# make sure they're using the right basis (between the
|
||||
# wall-time and monotonic clocks).
|
||||
start_time = time.time()
|
||||
response = self.fetch("/hello")
|
||||
response.rethrow()
|
||||
self.assertGreaterEqual(response.request_time, 0)
|
||||
self.assertLess(response.request_time, 1.0)
|
||||
# A very crude check to make sure that start_time is based on
|
||||
# wall time and not the monotonic clock.
|
||||
self.assertLess(abs(response.start_time - start_time), 1.0)
|
||||
|
||||
for k, v in response.time_info.items():
|
||||
self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
|
||||
|
||||
|
||||
class RequestProxyTest(unittest.TestCase):
|
||||
def test_request_set(self):
|
||||
|
|
@ -567,22 +609,20 @@ class HTTPResponseTestCase(unittest.TestCase):
|
|||
|
||||
class SyncHTTPClientTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if IOLoop.configured_class().__name__ in ('TwistedIOLoop',
|
||||
'AsyncIOMainLoop'):
|
||||
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
|
||||
# TwistedIOLoop only supports the global reactor, so we can't have
|
||||
# separate IOLoops for client and server threads.
|
||||
# AsyncIOMainLoop doesn't work with the default policy
|
||||
# (although it could with some tweaks to this test and a
|
||||
# policy that created loops for non-main threads).
|
||||
raise unittest.SkipTest(
|
||||
'Sync HTTPClient not compatible with TwistedIOLoop or '
|
||||
'AsyncIOMainLoop')
|
||||
'Sync HTTPClient not compatible with TwistedIOLoop')
|
||||
self.server_ioloop = IOLoop()
|
||||
|
||||
sock, self.port = bind_unused_port()
|
||||
app = Application([('/', HelloWorldHandler)])
|
||||
self.server = HTTPServer(app, io_loop=self.server_ioloop)
|
||||
self.server.add_socket(sock)
|
||||
@gen.coroutine
|
||||
def init_server():
|
||||
sock, self.port = bind_unused_port()
|
||||
app = Application([('/', HelloWorldHandler)])
|
||||
self.server = HTTPServer(app)
|
||||
self.server.add_socket(sock)
|
||||
self.server_ioloop.run_sync(init_server)
|
||||
|
||||
self.server_thread = threading.Thread(target=self.server_ioloop.start)
|
||||
self.server_thread.start()
|
||||
|
|
@ -592,12 +632,20 @@ class SyncHTTPClientTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
def stop_server():
|
||||
self.server.stop()
|
||||
# Delay the shutdown of the IOLoop by one iteration because
|
||||
# Delay the shutdown of the IOLoop by several iterations because
|
||||
# the server may still have some cleanup work left when
|
||||
# the client finishes with the response (this is noticable
|
||||
# the client finishes with the response (this is noticeable
|
||||
# with http/2, which leaves a Future with an unexamined
|
||||
# StreamClosedError on the loop).
|
||||
self.server_ioloop.add_callback(self.server_ioloop.stop)
|
||||
|
||||
@gen.coroutine
|
||||
def slow_stop():
|
||||
# The number of iterations is difficult to predict. Typically,
|
||||
# one is sufficient, although sometimes it needs more.
|
||||
for i in range(5):
|
||||
yield
|
||||
self.server_ioloop.stop()
|
||||
self.server_ioloop.add_callback(slow_stop)
|
||||
self.server_ioloop.add_callback(stop_server)
|
||||
self.server_thread.join()
|
||||
self.http_client.close()
|
||||
|
|
@ -656,6 +704,15 @@ class HTTPErrorTestCase(unittest.TestCase):
|
|||
self.assertIsNot(e, e2)
|
||||
self.assertEqual(e.code, e2.code)
|
||||
|
||||
def test_str(self):
|
||||
def test_plain_error(self):
|
||||
e = HTTPError(403)
|
||||
self.assertEqual(str(e), "HTTP 403: Forbidden")
|
||||
self.assertEqual(repr(e), "HTTP 403: Forbidden")
|
||||
|
||||
def test_error_with_response(self):
|
||||
resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
resp.rethrow()
|
||||
e = cm.exception
|
||||
self.assertEqual(str(e), "HTTP 403: Forbidden")
|
||||
self.assertEqual(repr(e), "HTTP 403: Forbidden")
|
||||
|
|
|
|||
|
|
@ -1,21 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado import netutil
|
||||
from tornado import gen, netutil
|
||||
from tornado.concurrent import Future
|
||||
from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
|
||||
from tornado import gen
|
||||
from tornado.http1connection import HTTP1Connection
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
|
||||
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.locks import Event
|
||||
from tornado.log import gen_log
|
||||
from tornado.netutil import ssl_options_to_context
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501
|
||||
from tornado.test.util import unittest, skipOnTravis
|
||||
from tornado.util import u
|
||||
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
|
||||
from tornado.web import Application, RequestHandler, stream_request_body
|
||||
|
||||
from contextlib import closing
|
||||
import datetime
|
||||
import gzip
|
||||
|
|
@ -30,18 +30,20 @@ from io import BytesIO
|
|||
|
||||
def read_stream_body(stream, callback):
|
||||
"""Reads an HTTP response from `stream` and runs callback with its
|
||||
headers and body."""
|
||||
start_line, headers and body."""
|
||||
chunks = []
|
||||
|
||||
class Delegate(HTTPMessageDelegate):
|
||||
def headers_received(self, start_line, headers):
|
||||
self.headers = headers
|
||||
self.start_line = start_line
|
||||
|
||||
def data_received(self, chunk):
|
||||
chunks.append(chunk)
|
||||
|
||||
def finish(self):
|
||||
callback((self.headers, b''.join(chunks)))
|
||||
conn.detach()
|
||||
callback((self.start_line, self.headers, b''.join(chunks)))
|
||||
conn = HTTP1Connection(stream, True)
|
||||
conn.read_response(Delegate())
|
||||
|
||||
|
|
@ -87,7 +89,7 @@ class BaseSSLTest(AsyncHTTPSTestCase):
|
|||
|
||||
class SSLTestMixin(object):
|
||||
def get_ssl_options(self):
|
||||
return dict(ssl_version=self.get_ssl_version(),
|
||||
return dict(ssl_version=self.get_ssl_version(), # type: ignore
|
||||
**AsyncHTTPSTestCase.get_ssl_options())
|
||||
|
||||
def get_ssl_version(self):
|
||||
|
|
@ -109,22 +111,19 @@ class SSLTestMixin(object):
|
|||
# misbehaving.
|
||||
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
|
||||
with ExpectLog(gen_log, 'Uncaught exception', required=False):
|
||||
self.http_client.fetch(
|
||||
self.get_url("/").replace('https:', 'http:'),
|
||||
self.stop,
|
||||
request_timeout=3600,
|
||||
connect_timeout=3600)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises((IOError, HTTPError)):
|
||||
self.fetch(
|
||||
self.get_url("/").replace('https:', 'http:'),
|
||||
request_timeout=3600,
|
||||
connect_timeout=3600,
|
||||
raise_error=True)
|
||||
|
||||
def test_error_logging(self):
|
||||
# No stack traces are logged for SSL errors.
|
||||
with ExpectLog(gen_log, 'SSL Error') as expect_log:
|
||||
self.http_client.fetch(
|
||||
self.get_url("/").replace("https:", "http:"),
|
||||
self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises((IOError, HTTPError)):
|
||||
self.fetch(self.get_url("/").replace("https:", "http:"),
|
||||
raise_error=True)
|
||||
self.assertFalse(expect_log.logged_stack)
|
||||
|
||||
# Python's SSL implementation differs significantly between versions.
|
||||
|
|
@ -150,7 +149,6 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin):
|
|||
return ssl.PROTOCOL_TLSv1
|
||||
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
|
||||
class SSLContextTest(BaseSSLTest, SSLTestMixin):
|
||||
def get_ssl_options(self):
|
||||
context = ssl_options_to_context(
|
||||
|
|
@ -211,14 +209,13 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
|
||||
def raw_fetch(self, headers, body, newline=b"\r\n"):
|
||||
with closing(IOStream(socket.socket())) as stream:
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
self.io_loop.run_sync(lambda: stream.connect(('127.0.0.1', self.get_http_port())))
|
||||
stream.write(
|
||||
newline.join(headers +
|
||||
[utf8("Content-Length: %d" % len(body))]) +
|
||||
newline + newline + body)
|
||||
read_stream_body(stream, self.stop)
|
||||
headers, body = self.wait()
|
||||
start_line, headers, body = self.wait()
|
||||
return body
|
||||
|
||||
def test_multipart_form(self):
|
||||
|
|
@ -232,19 +229,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
b"\r\n".join([
|
||||
b"Content-Disposition: form-data; name=argument",
|
||||
b"",
|
||||
u("\u00e1").encode("utf-8"),
|
||||
u"\u00e1".encode("utf-8"),
|
||||
b"--1234567890",
|
||||
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
|
||||
u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
|
||||
b"",
|
||||
u("\u00fa").encode("utf-8"),
|
||||
u"\u00fa".encode("utf-8"),
|
||||
b"--1234567890--",
|
||||
b"",
|
||||
]))
|
||||
data = json_decode(response)
|
||||
self.assertEqual(u("\u00e9"), data["header"])
|
||||
self.assertEqual(u("\u00e1"), data["argument"])
|
||||
self.assertEqual(u("\u00f3"), data["filename"])
|
||||
self.assertEqual(u("\u00fa"), data["filebody"])
|
||||
self.assertEqual(u"\u00e9", data["header"])
|
||||
self.assertEqual(u"\u00e1", data["argument"])
|
||||
self.assertEqual(u"\u00f3", data["filename"])
|
||||
self.assertEqual(u"\u00fa", data["filebody"])
|
||||
|
||||
def test_newlines(self):
|
||||
# We support both CRLF and bare LF as line separators.
|
||||
|
|
@ -253,31 +250,27 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
newline=newline)
|
||||
self.assertEqual(response, b'Hello world')
|
||||
|
||||
@gen_test
|
||||
def test_100_continue(self):
|
||||
# Run through a 100-continue interaction by hand:
|
||||
# When given Expect: 100-continue, we get a 100 response after the
|
||||
# headers, and then the real response after the body.
|
||||
stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
|
||||
self.wait()
|
||||
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
|
||||
b"Content-Length: 1024",
|
||||
b"Expect: 100-continue",
|
||||
b"Connection: close",
|
||||
b"\r\n"]), callback=self.stop)
|
||||
self.wait()
|
||||
stream.read_until(b"\r\n\r\n", self.stop)
|
||||
data = self.wait()
|
||||
stream = IOStream(socket.socket())
|
||||
yield stream.connect(("127.0.0.1", self.get_http_port()))
|
||||
yield stream.write(b"\r\n".join([
|
||||
b"POST /hello HTTP/1.1",
|
||||
b"Content-Length: 1024",
|
||||
b"Expect: 100-continue",
|
||||
b"Connection: close",
|
||||
b"\r\n"]))
|
||||
data = yield stream.read_until(b"\r\n\r\n")
|
||||
self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
|
||||
stream.write(b"a" * 1024)
|
||||
stream.read_until(b"\r\n", self.stop)
|
||||
first_line = self.wait()
|
||||
first_line = yield stream.read_until(b"\r\n")
|
||||
self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
|
||||
stream.read_until(b"\r\n\r\n", self.stop)
|
||||
header_data = self.wait()
|
||||
header_data = yield stream.read_until(b"\r\n\r\n")
|
||||
headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
|
||||
stream.read_bytes(int(headers["Content-Length"]), self.stop)
|
||||
body = self.wait()
|
||||
body = yield stream.read_bytes(int(headers["Content-Length"]))
|
||||
self.assertEqual(body, b"Got 1024 bytes in POST")
|
||||
stream.close()
|
||||
|
||||
|
|
@ -340,17 +333,17 @@ class HTTPServerTest(AsyncHTTPTestCase):
|
|||
def test_query_string_encoding(self):
|
||||
response = self.fetch("/echo?foo=%C3%A9")
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u("foo"): [u("\u00e9")]})
|
||||
self.assertEqual(data, {u"foo": [u"\u00e9"]})
|
||||
|
||||
def test_empty_query_string(self):
|
||||
response = self.fetch("/echo?foo=&foo=")
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u("foo"): [u(""), u("")]})
|
||||
self.assertEqual(data, {u"foo": [u"", u""]})
|
||||
|
||||
def test_empty_post_parameters(self):
|
||||
response = self.fetch("/echo", method="POST", body="foo=&bar=")
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u("foo"): [u("")], u("bar"): [u("")]})
|
||||
self.assertEqual(data, {u"foo": [u""], u"bar": [u""]})
|
||||
|
||||
def test_types(self):
|
||||
headers = {"Cookie": "foo=bar"}
|
||||
|
|
@ -395,8 +388,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
|
|||
def setUp(self):
|
||||
super(HTTPServerRawTest, self).setUp()
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port())))
|
||||
|
||||
def tearDown(self):
|
||||
self.stream.close()
|
||||
|
|
@ -407,19 +399,28 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
|
|||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_malformed_first_line(self):
|
||||
def test_malformed_first_line_response(self):
|
||||
with ExpectLog(gen_log, '.*Malformed HTTP request line'):
|
||||
self.stream.write(b'asdf\r\n\r\n')
|
||||
read_stream_body(self.stream, self.stop)
|
||||
start_line, headers, response = self.wait()
|
||||
self.assertEqual('HTTP/1.1', start_line.version)
|
||||
self.assertEqual(400, start_line.code)
|
||||
self.assertEqual('Bad Request', start_line.reason)
|
||||
|
||||
def test_malformed_first_line_log(self):
|
||||
with ExpectLog(gen_log, '.*Malformed HTTP request line'):
|
||||
self.stream.write(b'asdf\r\n\r\n')
|
||||
# TODO: need an async version of ExpectLog so we don't need
|
||||
# hard-coded timeouts here.
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
|
||||
self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_malformed_headers(self):
|
||||
with ExpectLog(gen_log, '.*Malformed HTTP headers'):
|
||||
with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'):
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
|
||||
self.stop)
|
||||
self.wait()
|
||||
|
||||
|
|
@ -439,18 +440,50 @@ bar
|
|||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
read_stream_body(self.stream, self.stop)
|
||||
headers, response = self.wait()
|
||||
self.assertEqual(json_decode(response), {u('foo'): [u('bar')]})
|
||||
start_line, headers, response = self.wait()
|
||||
self.assertEqual(json_decode(response), {u'foo': [u'bar']})
|
||||
|
||||
def test_chunked_request_uppercase(self):
|
||||
# As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
|
||||
# case-insensitive.
|
||||
self.stream.write(b"""\
|
||||
POST /echo HTTP/1.1
|
||||
Transfer-Encoding: Chunked
|
||||
Content-Type: application/x-www-form-urlencoded
|
||||
|
||||
4
|
||||
foo=
|
||||
3
|
||||
bar
|
||||
0
|
||||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
read_stream_body(self.stream, self.stop)
|
||||
start_line, headers, response = self.wait()
|
||||
self.assertEqual(json_decode(response), {u'foo': [u'bar']})
|
||||
|
||||
@gen_test
|
||||
def test_invalid_content_length(self):
|
||||
with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
|
||||
self.stream.write(b"""\
|
||||
POST /echo HTTP/1.1
|
||||
Content-Length: foo
|
||||
|
||||
bar
|
||||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
yield self.stream.read_until_close()
|
||||
|
||||
|
||||
class XHeaderTest(HandlerBaseTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header('request-version', self.request.version)
|
||||
self.write(dict(remote_ip=self.request.remote_ip,
|
||||
remote_protocol=self.request.protocol))
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(xheaders=True)
|
||||
return dict(xheaders=True, trusted_downstream=['5.5.5.5'])
|
||||
|
||||
def test_ip_headers(self):
|
||||
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
|
||||
|
|
@ -490,6 +523,16 @@ class XHeaderTest(HandlerBaseTestCase):
|
|||
self.fetch_json("/", headers=invalid_host)["remote_ip"],
|
||||
"127.0.0.1")
|
||||
|
||||
def test_trusted_downstream(self):
|
||||
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}
|
||||
resp = self.fetch("/", headers=valid_ipv4_list)
|
||||
if resp.headers['request-version'].startswith('HTTP/2'):
|
||||
# This is a hack - there's nothing that fundamentally requires http/1
|
||||
# here but tornado_http2 doesn't support it yet.
|
||||
self.skipTest('requires HTTP/1.x')
|
||||
result = json_decode(resp.body)
|
||||
self.assertEqual(result['remote_ip'], "4.4.4.4")
|
||||
|
||||
def test_scheme_headers(self):
|
||||
self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")
|
||||
|
||||
|
|
@ -503,6 +546,16 @@ class XHeaderTest(HandlerBaseTestCase):
|
|||
self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
|
||||
"https")
|
||||
|
||||
https_multi_forwarded = {"X-Forwarded-Proto": "https , http"}
|
||||
self.assertEqual(
|
||||
self.fetch_json("/", headers=https_multi_forwarded)["remote_protocol"],
|
||||
"http")
|
||||
|
||||
http_multi_forwarded = {"X-Forwarded-Proto": "http,https"}
|
||||
self.assertEqual(
|
||||
self.fetch_json("/", headers=http_multi_forwarded)["remote_protocol"],
|
||||
"https")
|
||||
|
||||
bad_forwarded = {"X-Forwarded-Proto": "unknown"}
|
||||
self.assertEqual(
|
||||
self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
|
||||
|
|
@ -560,37 +613,36 @@ class UnixSocketTest(AsyncTestCase):
|
|||
self.sockfile = os.path.join(self.tmpdir, "test.sock")
|
||||
sock = netutil.bind_unix_socket(self.sockfile)
|
||||
app = Application([("/hello", HelloWorldRequestHandler)])
|
||||
self.server = HTTPServer(app, io_loop=self.io_loop)
|
||||
self.server = HTTPServer(app)
|
||||
self.server.add_socket(sock)
|
||||
self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
|
||||
self.stream.connect(self.sockfile, self.stop)
|
||||
self.wait()
|
||||
self.stream = IOStream(socket.socket(socket.AF_UNIX))
|
||||
self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))
|
||||
|
||||
def tearDown(self):
|
||||
self.stream.close()
|
||||
self.io_loop.run_sync(self.server.close_all_connections)
|
||||
self.server.stop()
|
||||
shutil.rmtree(self.tmpdir)
|
||||
super(UnixSocketTest, self).tearDown()
|
||||
|
||||
@gen_test
|
||||
def test_unix_socket(self):
|
||||
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
|
||||
self.stream.read_until(b"\r\n", self.stop)
|
||||
response = self.wait()
|
||||
response = yield self.stream.read_until(b"\r\n")
|
||||
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
|
||||
self.stream.read_until(b"\r\n\r\n", self.stop)
|
||||
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
|
||||
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
|
||||
body = self.wait()
|
||||
header_data = yield self.stream.read_until(b"\r\n\r\n")
|
||||
headers = HTTPHeaders.parse(header_data.decode('latin1'))
|
||||
body = yield self.stream.read_bytes(int(headers["Content-Length"]))
|
||||
self.assertEqual(body, b"Hello world")
|
||||
|
||||
@gen_test
|
||||
def test_unix_socket_bad_request(self):
|
||||
# Unix sockets don't have remote addresses so they just return an
|
||||
# empty string.
|
||||
with ExpectLog(gen_log, "Malformed HTTP message from"):
|
||||
self.stream.write(b"garbage\r\n\r\n")
|
||||
self.stream.read_until_close(self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response, b"")
|
||||
response = yield self.stream.read_until_close()
|
||||
self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
|
||||
|
||||
|
||||
class KeepAliveTest(AsyncHTTPTestCase):
|
||||
|
|
@ -614,9 +666,11 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
self.write(''.join(chr(i % 256) * 1024 for i in range(512)))
|
||||
|
||||
class FinishOnCloseHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.flush()
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
def on_connection_close(self):
|
||||
# This is not very realistic, but finishing the request
|
||||
|
|
@ -643,119 +697,129 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
super(KeepAliveTest, self).tearDown()
|
||||
|
||||
# The next few methods are a crude manual http client
|
||||
@gen.coroutine
|
||||
def connect(self):
|
||||
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
self.stream = IOStream(socket.socket())
|
||||
yield self.stream.connect(('127.0.0.1', self.get_http_port()))
|
||||
|
||||
@gen.coroutine
|
||||
def read_headers(self):
|
||||
self.stream.read_until(b'\r\n', self.stop)
|
||||
first_line = self.wait()
|
||||
first_line = yield self.stream.read_until(b'\r\n')
|
||||
self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
|
||||
self.stream.read_until(b'\r\n\r\n', self.stop)
|
||||
header_bytes = self.wait()
|
||||
header_bytes = yield self.stream.read_until(b'\r\n\r\n')
|
||||
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
|
||||
return headers
|
||||
raise gen.Return(headers)
|
||||
|
||||
@gen.coroutine
|
||||
def read_response(self):
|
||||
self.headers = self.read_headers()
|
||||
self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
|
||||
body = self.wait()
|
||||
self.headers = yield self.read_headers()
|
||||
body = yield self.stream.read_bytes(int(self.headers['Content-Length']))
|
||||
self.assertEqual(b'Hello world', body)
|
||||
|
||||
def close(self):
|
||||
self.stream.close()
|
||||
del self.stream
|
||||
|
||||
@gen_test
|
||||
def test_two_requests(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_request_close(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
|
||||
self.read_response()
|
||||
self.stream.read_until_close(callback=self.stop)
|
||||
data = self.wait()
|
||||
yield self.read_response()
|
||||
data = yield self.stream.read_until_close()
|
||||
self.assertTrue(not data)
|
||||
self.assertEqual(self.headers['Connection'], 'close')
|
||||
self.close()
|
||||
|
||||
# keepalive is supported for http 1.0 too, but it's opt-in
|
||||
@gen_test
|
||||
def test_http10(self):
|
||||
self.http_version = b'HTTP/1.0'
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
|
||||
self.read_response()
|
||||
self.stream.read_until_close(callback=self.stop)
|
||||
data = self.wait()
|
||||
yield self.read_response()
|
||||
data = yield self.stream.read_until_close()
|
||||
self.assertTrue(not data)
|
||||
self.assertTrue('Connection' not in self.headers)
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_http10_keepalive(self):
|
||||
self.http_version = b'HTTP/1.0'
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_http10_keepalive_extra_crlf(self):
|
||||
self.http_version = b'HTTP/1.0'
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_pipelined_requests(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
|
||||
self.read_response()
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
yield self.read_response()
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_pipelined_cancel(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
|
||||
# only read once
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_cancel_during_download(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
|
||||
self.read_headers()
|
||||
self.stream.read_bytes(1024, self.stop)
|
||||
self.wait()
|
||||
yield self.read_headers()
|
||||
yield self.stream.read_bytes(1024)
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_finish_while_closed(self):
|
||||
self.connect()
|
||||
yield self.connect()
|
||||
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
|
||||
self.read_headers()
|
||||
yield self.read_headers()
|
||||
self.close()
|
||||
|
||||
@gen_test
|
||||
def test_keepalive_chunked(self):
|
||||
self.http_version = b'HTTP/1.0'
|
||||
self.connect()
|
||||
self.stream.write(b'POST / HTTP/1.0\r\nConnection: keep-alive\r\n'
|
||||
yield self.connect()
|
||||
self.stream.write(b'POST / HTTP/1.0\r\n'
|
||||
b'Connection: keep-alive\r\n'
|
||||
b'Transfer-Encoding: chunked\r\n'
|
||||
b'\r\n0\r\n')
|
||||
self.read_response()
|
||||
b'\r\n'
|
||||
b'0\r\n'
|
||||
b'\r\n')
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
|
||||
self.read_response()
|
||||
yield self.read_response()
|
||||
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
|
||||
self.close()
|
||||
|
||||
|
|
@ -775,7 +839,7 @@ class GzipBaseTest(object):
|
|||
|
||||
def test_uncompressed(self):
|
||||
response = self.fetch('/', method='POST', body='foo=bar')
|
||||
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
|
||||
self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
|
||||
|
||||
|
||||
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
|
||||
|
|
@ -784,7 +848,7 @@ class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
|
|||
|
||||
def test_gzip(self):
|
||||
response = self.post_gzip('foo=bar')
|
||||
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
|
||||
self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
|
||||
|
||||
|
||||
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
|
||||
|
|
@ -805,7 +869,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
|
|||
def get_http_client(self):
|
||||
# body_producer doesn't work on curl_httpclient, so override the
|
||||
# configured AsyncHTTPClient implementation.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
|
||||
|
|
@ -900,11 +964,14 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
|||
|
||||
def test_large_headers(self):
|
||||
with ExpectLog(gen_log, "Unsatisfiable read", required=False):
|
||||
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
|
||||
# 431 is "Request Header Fields Too Large", defined in RFC
|
||||
# 6585. However, many implementations just close the
|
||||
# connection in this case, resulting in a 599.
|
||||
self.assertIn(response.code, (431, 599))
|
||||
try:
|
||||
self.fetch("/", headers={'X-Filler': 'a' * 1000}, raise_error=True)
|
||||
self.fail("did not raise expected exception")
|
||||
except HTTPError as e:
|
||||
# 431 is "Request Header Fields Too Large", defined in RFC
|
||||
# 6585. However, many implementations just close the
|
||||
# connection in this case, resulting in a 599.
|
||||
self.assertIn(e.response.code, (431, 599))
|
||||
|
||||
|
||||
@skipOnTravis
|
||||
|
|
@ -924,34 +991,35 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
|
|||
for stream in self.streams:
|
||||
stream.close()
|
||||
|
||||
@gen.coroutine
|
||||
def connect(self):
|
||||
stream = IOStream(socket.socket())
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
yield stream.connect(('127.0.0.1', self.get_http_port()))
|
||||
self.streams.append(stream)
|
||||
return stream
|
||||
raise gen.Return(stream)
|
||||
|
||||
@gen_test
|
||||
def test_unused_connection(self):
|
||||
stream = self.connect()
|
||||
stream.set_close_callback(self.stop)
|
||||
self.wait()
|
||||
stream = yield self.connect()
|
||||
event = Event()
|
||||
stream.set_close_callback(event.set)
|
||||
yield event.wait()
|
||||
|
||||
@gen_test
|
||||
def test_idle_after_use(self):
|
||||
stream = self.connect()
|
||||
stream.set_close_callback(lambda: self.stop("closed"))
|
||||
stream = yield self.connect()
|
||||
event = Event()
|
||||
stream.set_close_callback(event.set)
|
||||
|
||||
# Use the connection twice to make sure keep-alives are working
|
||||
for i in range(2):
|
||||
stream.write(b"GET / HTTP/1.1\r\n\r\n")
|
||||
stream.read_until(b"\r\n\r\n", self.stop)
|
||||
self.wait()
|
||||
stream.read_bytes(11, self.stop)
|
||||
data = self.wait()
|
||||
yield stream.read_until(b"\r\n\r\n")
|
||||
data = yield stream.read_bytes(11)
|
||||
self.assertEqual(data, b"Hello world")
|
||||
|
||||
# Now let the timeout trigger and close the connection.
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "closed")
|
||||
yield event.wait()
|
||||
|
||||
|
||||
class BodyLimitsTest(AsyncHTTPTestCase):
|
||||
|
|
@ -988,7 +1056,7 @@ class BodyLimitsTest(AsyncHTTPTestCase):
|
|||
def get_http_client(self):
|
||||
# body_producer doesn't work on curl_httpclient, so override the
|
||||
# configured AsyncHTTPClient implementation.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def test_small_body(self):
|
||||
response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
|
||||
|
|
@ -999,24 +1067,27 @@ class BodyLimitsTest(AsyncHTTPTestCase):
|
|||
def test_large_body_buffered(self):
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'flaky on windows')
|
||||
def test_large_body_buffered_chunked(self):
|
||||
# This test is flaky on windows for unknown reasons.
|
||||
with ExpectLog(gen_log, '.*chunked body too large'):
|
||||
response = self.fetch('/buffered', method='PUT',
|
||||
body_producer=lambda write: write(b'a' * 10240))
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_large_body_streaming(self):
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'flaky on windows')
|
||||
def test_large_body_streaming_chunked(self):
|
||||
with ExpectLog(gen_log, '.*chunked body too large'):
|
||||
response = self.fetch('/streaming', method='PUT',
|
||||
body_producer=lambda write: write(b'a' * 10240))
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_large_body_streaming_override(self):
|
||||
response = self.fetch('/streaming?expected_size=10240', method='PUT',
|
||||
|
|
@ -1053,14 +1124,16 @@ class BodyLimitsTest(AsyncHTTPTestCase):
|
|||
stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
|
||||
b'Content-Length: 10240\r\n\r\n')
|
||||
stream.write(b'a' * 10240)
|
||||
headers, response = yield gen.Task(read_stream_body, stream)
|
||||
fut = Future()
|
||||
read_stream_body(stream, callback=fut.set_result)
|
||||
start_line, headers, response = yield fut
|
||||
self.assertEqual(response, b'10240')
|
||||
# Without the ?expected_size parameter, we get the old default value
|
||||
stream.write(b'PUT /streaming HTTP/1.1\r\n'
|
||||
b'Content-Length: 10240\r\n\r\n')
|
||||
with ExpectLog(gen_log, '.*Content-Length too long'):
|
||||
data = yield stream.read_until_close()
|
||||
self.assertEqual(data, b'')
|
||||
self.assertEqual(data, b'HTTP/1.1 400 Bad Request\r\n\r\n')
|
||||
finally:
|
||||
stream.close()
|
||||
|
||||
|
|
@ -1081,10 +1154,10 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
|
|||
request.connection.finish()
|
||||
return
|
||||
message = b"Hello world"
|
||||
request.write(utf8("HTTP/1.1 200 OK\r\n"
|
||||
"Content-Length: %d\r\n\r\n" % len(message)))
|
||||
request.write(message)
|
||||
request.finish()
|
||||
request.connection.write(utf8("HTTP/1.1 200 OK\r\n"
|
||||
"Content-Length: %d\r\n\r\n" % len(message)))
|
||||
request.connection.write(message)
|
||||
request.connection.finish()
|
||||
return handle_request
|
||||
|
||||
def test_legacy_interface(self):
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
|
||||
from tornado.httputil import (
|
||||
url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp,
|
||||
HTTPServerRequest, parse_request_start_line, parse_cookie, qs_to_qsl,
|
||||
HTTPInputError,
|
||||
)
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado.util import PY3
|
||||
from tornado.log import gen_log
|
||||
from tornado.testing import ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import u
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
|
|
@ -15,6 +18,11 @@ import logging
|
|||
import pickle
|
||||
import time
|
||||
|
||||
if PY3:
|
||||
import urllib.parse as urllib_parse
|
||||
else:
|
||||
import urlparse as urllib_parse
|
||||
|
||||
|
||||
class TestUrlConcat(unittest.TestCase):
|
||||
def test_url_concat_no_query_params(self):
|
||||
|
|
@ -43,14 +51,14 @@ class TestUrlConcat(unittest.TestCase):
|
|||
"https://localhost/path?x",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
|
||||
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
|
||||
|
||||
def test_url_concat_trailing_amp(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?x&",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
|
||||
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
|
||||
|
||||
def test_url_concat_mult_params(self):
|
||||
url = url_concat(
|
||||
|
|
@ -66,6 +74,52 @@ class TestUrlConcat(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&t=2")
|
||||
|
||||
def test_url_concat_none_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?r=1&t=2",
|
||||
None,
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&t=2")
|
||||
|
||||
def test_url_concat_with_frag(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path#tab",
|
||||
[('y', 'y')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y#tab")
|
||||
|
||||
def test_url_concat_multi_same_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
[('y', 'y1'), ('y', 'y2')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
|
||||
|
||||
def test_url_concat_multi_same_query_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?r=1&r=2",
|
||||
[('y', 'y')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
|
||||
|
||||
def test_url_concat_dict_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
dict(y='y'),
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y")
|
||||
|
||||
|
||||
class QsParseTest(unittest.TestCase):
|
||||
|
||||
def test_parsing(self):
|
||||
qsstring = "a=1&b=2&a=3"
|
||||
qs = urllib_parse.parse_qs(qsstring)
|
||||
qsl = list(qs_to_qsl(qs))
|
||||
self.assertIn(('a', '1'), qsl)
|
||||
self.assertIn(('a', '3'), qsl)
|
||||
self.assertIn(('b', '2'), qsl)
|
||||
|
||||
|
||||
class MultipartFormDataTest(unittest.TestCase):
|
||||
def test_file_upload(self):
|
||||
|
|
@ -122,6 +176,20 @@ Foo
|
|||
self.assertEqual(file["filename"], filename)
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_non_ascii_filename(self):
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt
|
||||
|
||||
Foo
|
||||
--1234--""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], u"áb.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_boundary_starts_and_ends_with_quotes(self):
|
||||
data = b'''\
|
||||
--1234
|
||||
|
|
@ -230,6 +298,13 @@ Foo: even
|
|||
("Foo", "bar baz"),
|
||||
("Foo", "even more lines")])
|
||||
|
||||
def test_malformed_continuation(self):
|
||||
# If the first line starts with whitespace, it's a
|
||||
# continuation line with nothing to continue, so reject it
|
||||
# (with a proper error).
|
||||
data = " Foo: bar"
|
||||
self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)
|
||||
|
||||
def test_unicode_newlines(self):
|
||||
# Ensure that only \r\n is recognized as a header separator, and not
|
||||
# the other newline-like unicode characters.
|
||||
|
|
@ -238,13 +313,13 @@ Foo: even
|
|||
# and cpython's unicodeobject.c (which defines the implementation
|
||||
# of unicode_type.splitlines(), and uses a different list than TR13).
|
||||
newlines = [
|
||||
u('\u001b'), # VERTICAL TAB
|
||||
u('\u001c'), # FILE SEPARATOR
|
||||
u('\u001d'), # GROUP SEPARATOR
|
||||
u('\u001e'), # RECORD SEPARATOR
|
||||
u('\u0085'), # NEXT LINE
|
||||
u('\u2028'), # LINE SEPARATOR
|
||||
u('\u2029'), # PARAGRAPH SEPARATOR
|
||||
u'\u001b', # VERTICAL TAB
|
||||
u'\u001c', # FILE SEPARATOR
|
||||
u'\u001d', # GROUP SEPARATOR
|
||||
u'\u001e', # RECORD SEPARATOR
|
||||
u'\u0085', # NEXT LINE
|
||||
u'\u2028', # LINE SEPARATOR
|
||||
u'\u2029', # PARAGRAPH SEPARATOR
|
||||
]
|
||||
for newline in newlines:
|
||||
# Try the utf8 and latin1 representations of each newline
|
||||
|
|
@ -319,6 +394,14 @@ Foo: even
|
|||
self.assertEqual(headers['quux'], 'xyzzy')
|
||||
self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
|
||||
|
||||
def test_string(self):
|
||||
headers = HTTPHeaders()
|
||||
headers.add("Foo", "1")
|
||||
headers.add("Foo", "2")
|
||||
headers.add("Foo", "3")
|
||||
headers2 = HTTPHeaders.parse(str(headers))
|
||||
self.assertEquals(headers, headers2)
|
||||
|
||||
|
||||
class FormatTimestampTest(unittest.TestCase):
|
||||
# Make sure that all the input types are supported.
|
||||
|
|
@ -359,6 +442,10 @@ class HTTPServerRequestTest(unittest.TestCase):
|
|||
requets = HTTPServerRequest(uri='/')
|
||||
self.assertIsInstance(requets.body, bytes)
|
||||
|
||||
def test_repr_does_not_contain_headers(self):
|
||||
request = HTTPServerRequest(uri='/', headers={'Canary': 'Coal Mine'})
|
||||
self.assertTrue('Canary' not in repr(request))
|
||||
|
||||
|
||||
class ParseRequestStartLineTest(unittest.TestCase):
|
||||
METHOD = "GET"
|
||||
|
|
@ -371,3 +458,59 @@ class ParseRequestStartLineTest(unittest.TestCase):
|
|||
self.assertEqual(parsed_start_line.method, self.METHOD)
|
||||
self.assertEqual(parsed_start_line.path, self.PATH)
|
||||
self.assertEqual(parsed_start_line.version, self.VERSION)
|
||||
|
||||
|
||||
class ParseCookieTest(unittest.TestCase):
|
||||
# These tests copied from Django:
|
||||
# https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28
|
||||
def test_python_cookies(self):
|
||||
"""
|
||||
Test cases copied from Python's Lib/test/test_http_cookies.py
|
||||
"""
|
||||
self.assertEqual(parse_cookie('chips=ahoy; vienna=finger'),
|
||||
{'chips': 'ahoy', 'vienna': 'finger'})
|
||||
# Here parse_cookie() differs from Python's cookie parsing in that it
|
||||
# treats all semicolons as delimiters, even within quotes.
|
||||
self.assertEqual(
|
||||
parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
|
||||
{'keebler': '"E=mc2', 'L': '\\"Loves\\"', 'fudge': '\\012', '': '"'}
|
||||
)
|
||||
# Illegal cookies that have an '=' char in an unquoted value.
|
||||
self.assertEqual(parse_cookie('keebler=E=mc2'), {'keebler': 'E=mc2'})
|
||||
# Cookies with ':' character in their name.
|
||||
self.assertEqual(parse_cookie('key:term=value:term'), {'key:term': 'value:term'})
|
||||
# Cookies with '[' and ']'.
|
||||
self.assertEqual(parse_cookie('a=b; c=[; d=r; f=h'),
|
||||
{'a': 'b', 'c': '[', 'd': 'r', 'f': 'h'})
|
||||
|
||||
def test_cookie_edgecases(self):
|
||||
# Cookies that RFC6265 allows.
|
||||
self.assertEqual(parse_cookie('a=b; Domain=example.com'),
|
||||
{'a': 'b', 'Domain': 'example.com'})
|
||||
# parse_cookie() has historically kept only the last cookie with the
|
||||
# same name.
|
||||
self.assertEqual(parse_cookie('a=b; h=i; a=c'), {'a': 'c', 'h': 'i'})
|
||||
|
||||
def test_invalid_cookies(self):
|
||||
"""
|
||||
Cookie strings that go against RFC6265 but browsers will send if set
|
||||
via document.cookie.
|
||||
"""
|
||||
# Chunks without an equals sign appear as unnamed values per
|
||||
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
||||
self.assertIn('django_language',
|
||||
parse_cookie('abc=def; unnamed; django_language=en').keys())
|
||||
# Even a double quote may be an unamed value.
|
||||
self.assertEqual(parse_cookie('a=b; "; c=d'), {'a': 'b', '': '"', 'c': 'd'})
|
||||
# Spaces in names and values, and an equals sign in values.
|
||||
self.assertEqual(parse_cookie('a b c=d e = f; gh=i'), {'a b c': 'd e = f', 'gh': 'i'})
|
||||
# More characters the spec forbids.
|
||||
self.assertEqual(parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
|
||||
{'a b,c<>@:/[]?{}': 'd " =e,f g'})
|
||||
# Unicode characters. The spec only allows ASCII.
|
||||
self.assertEqual(parse_cookie('saint=André Bessette'),
|
||||
{'saint': native_str('André Bessette')})
|
||||
# Browsers don't send extra whitespace or semicolons in Cookie headers,
|
||||
# but parse_cookie() should parse whitespace the same way
|
||||
# document.cookie parses whitespace.
|
||||
self.assertEqual(parse_cookie(' = b ; ; = ; c = ; '), {'': 'b', 'c': ''})
|
||||
|
|
|
|||
|
|
@ -1,47 +1,73 @@
|
|||
# flake8: noqa
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from tornado.test.util import unittest
|
||||
|
||||
_import_everything = b"""
|
||||
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
|
||||
# at startup, which in turn creates the default event loop and prevents forking.
|
||||
# Explicitly disallow the default event loop so that an error will be raised
|
||||
# if something tries to touch it.
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
import tornado.auth
|
||||
import tornado.autoreload
|
||||
import tornado.concurrent
|
||||
import tornado.escape
|
||||
import tornado.gen
|
||||
import tornado.http1connection
|
||||
import tornado.httpclient
|
||||
import tornado.httpserver
|
||||
import tornado.httputil
|
||||
import tornado.ioloop
|
||||
import tornado.iostream
|
||||
import tornado.locale
|
||||
import tornado.log
|
||||
import tornado.netutil
|
||||
import tornado.options
|
||||
import tornado.process
|
||||
import tornado.simple_httpclient
|
||||
import tornado.stack_context
|
||||
import tornado.tcpserver
|
||||
import tornado.tcpclient
|
||||
import tornado.template
|
||||
import tornado.testing
|
||||
import tornado.util
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
import tornado.wsgi
|
||||
|
||||
try:
|
||||
import pycurl
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
import tornado.curl_httpclient
|
||||
"""
|
||||
|
||||
|
||||
class ImportTest(unittest.TestCase):
|
||||
def test_import_everything(self):
|
||||
# Some of our modules are not otherwise tested. Import them
|
||||
# all (unless they have external dependencies) here to at
|
||||
# least ensure that there are no syntax errors.
|
||||
import tornado.auth
|
||||
import tornado.autoreload
|
||||
import tornado.concurrent
|
||||
# import tornado.curl_httpclient # depends on pycurl
|
||||
import tornado.escape
|
||||
import tornado.gen
|
||||
import tornado.http1connection
|
||||
import tornado.httpclient
|
||||
import tornado.httpserver
|
||||
import tornado.httputil
|
||||
# Test that all Tornado modules can be imported without side effects,
|
||||
# specifically without initializing the default asyncio event loop.
|
||||
# Since we can't tell which modules may have already beein imported
|
||||
# in our process, do it in a subprocess for a clean slate.
|
||||
proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
|
||||
proc.communicate(_import_everything)
|
||||
self.assertEqual(proc.returncode, 0)
|
||||
|
||||
def test_import_aliases(self):
|
||||
# Ensure we don't delete formerly-documented aliases accidentally.
|
||||
import tornado.ioloop
|
||||
import tornado.iostream
|
||||
import tornado.locale
|
||||
import tornado.log
|
||||
import tornado.netutil
|
||||
import tornado.options
|
||||
import tornado.process
|
||||
import tornado.simple_httpclient
|
||||
import tornado.stack_context
|
||||
import tornado.tcpserver
|
||||
import tornado.template
|
||||
import tornado.testing
|
||||
import tornado.gen
|
||||
import tornado.util
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
import tornado.wsgi
|
||||
|
||||
# for modules with dependencies, if those dependencies can be loaded,
|
||||
# load them too.
|
||||
|
||||
def test_import_pycurl(self):
|
||||
try:
|
||||
import pycurl
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
import tornado.curl_httpclient
|
||||
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
|
||||
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)
|
||||
|
|
|
|||
|
|
@ -1,28 +1,48 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
try:
|
||||
from unittest import mock # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
import mock # type: ignore
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
from tornado.escape import native_str
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
|
||||
from tornado.log import app_log
|
||||
from tornado.platform.select import _Select
|
||||
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis, skipBefore35, exec_test
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
|
||||
from tornado.test.util import (unittest, skipIfNonUnix, skipOnTravis,
|
||||
skipBefore35, exec_test, ignore_deprecation)
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
try:
|
||||
import twisted
|
||||
except ImportError:
|
||||
twisted = None
|
||||
|
||||
|
||||
class FakeTimeSelect(_Select):
|
||||
def __init__(self):
|
||||
|
|
@ -61,6 +81,25 @@ class FakeTimeIOLoop(PollIOLoop):
|
|||
|
||||
|
||||
class TestIOLoop(AsyncTestCase):
|
||||
def test_add_callback_return_sequence(self):
|
||||
# A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
|
||||
self.calls = 0
|
||||
|
||||
loop = self.io_loop
|
||||
test = self
|
||||
old_add_callback = loop.add_callback
|
||||
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
test.calls += 1
|
||||
old_add_callback(callback, *args, **kwargs)
|
||||
|
||||
loop.add_callback = types.MethodType(add_callback, loop)
|
||||
loop.add_callback(lambda: {})
|
||||
loop.add_callback(lambda: [])
|
||||
loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
|
||||
loop.start()
|
||||
self.assertLess(self.calls, 10)
|
||||
|
||||
@skipOnTravis
|
||||
def test_add_callback_wakeup(self):
|
||||
# Make sure that add_callback from inside a running IOLoop
|
||||
|
|
@ -138,8 +177,9 @@ class TestIOLoop(AsyncTestCase):
|
|||
other_ioloop.close()
|
||||
|
||||
def test_add_callback_while_closing(self):
|
||||
# Issue #635: add_callback() should raise a clean exception
|
||||
# if called while another thread is closing the IOLoop.
|
||||
# add_callback should not fail if it races with another thread
|
||||
# closing the IOLoop. The callbacks are dropped silently
|
||||
# without executing.
|
||||
closing = threading.Event()
|
||||
|
||||
def target():
|
||||
|
|
@ -152,11 +192,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
thread.start()
|
||||
closing.wait()
|
||||
for i in range(1000):
|
||||
try:
|
||||
other_ioloop.add_callback(lambda: None)
|
||||
except RuntimeError as e:
|
||||
self.assertEqual("IOLoop is closing", str(e))
|
||||
break
|
||||
other_ioloop.add_callback(lambda: None)
|
||||
|
||||
def test_handle_callback_exception(self):
|
||||
# IOLoop.handle_callback_exception can be overridden to catch
|
||||
|
|
@ -241,7 +277,9 @@ class TestIOLoop(AsyncTestCase):
|
|||
self.io_loop.call_later(0, results.append, 4)
|
||||
self.io_loop.call_later(0, self.stop)
|
||||
self.wait()
|
||||
self.assertEqual(results, [1, 2, 3, 4])
|
||||
# The asyncio event loop does not guarantee the order of these
|
||||
# callbacks, but PollIOLoop does.
|
||||
self.assertEqual(sorted(results), [1, 2, 3, 4])
|
||||
|
||||
def test_add_timeout_return(self):
|
||||
# All the timeout methods return non-None handles that can be
|
||||
|
|
@ -368,25 +406,29 @@ class TestIOLoop(AsyncTestCase):
|
|||
"""The IOLoop examines exceptions from awaitables and logs them."""
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def callback():
|
||||
self.io_loop.add_callback(self.stop)
|
||||
# Stop the IOLoop two iterations after raising an exception
|
||||
# to give the exception time to be logged.
|
||||
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
|
||||
1 / 0
|
||||
""")
|
||||
with NullContext():
|
||||
self.io_loop.add_callback(namespace["callback"])
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
def test_spawn_callback(self):
|
||||
# An added callback runs in the test's stack_context, so will be
|
||||
# re-arised in wait().
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.wait()
|
||||
# A spawned callback is run directly on the IOLoop, so it will be
|
||||
# logged without stopping the test.
|
||||
self.io_loop.spawn_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
with ignore_deprecation():
|
||||
# An added callback runs in the test's stack_context, so will be
|
||||
# re-raised in wait().
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.wait()
|
||||
# A spawned callback is run directly on the IOLoop, so it will be
|
||||
# logged without stopping the test.
|
||||
self.io_loop.spawn_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
@skipIfNonUnix
|
||||
def test_remove_handler_from_handler(self):
|
||||
|
|
@ -407,7 +449,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
self.io_loop.remove_handler(client)
|
||||
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
|
||||
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
|
||||
self.io_loop.call_later(0.03, self.stop)
|
||||
self.io_loop.call_later(0.1, self.stop)
|
||||
self.wait()
|
||||
|
||||
# Only one fd was read; the other was cleanly removed.
|
||||
|
|
@ -416,6 +458,16 @@ class TestIOLoop(AsyncTestCase):
|
|||
client.close()
|
||||
server.close()
|
||||
|
||||
@gen_test
|
||||
def test_init_close_race(self):
|
||||
# Regression test for #2367
|
||||
def f():
|
||||
for i in range(10):
|
||||
loop = IOLoop()
|
||||
loop.close()
|
||||
|
||||
yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
|
||||
|
||||
|
||||
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
|
||||
# automatically set as current.
|
||||
|
|
@ -463,6 +515,16 @@ class TestIOLoopCurrent(unittest.TestCase):
|
|||
self.assertIs(self.io_loop, IOLoop.current())
|
||||
|
||||
|
||||
class TestIOLoopCurrentAsync(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_clear_without_current(self):
|
||||
# If there is no current IOLoop, clear_current is a no-op (but
|
||||
# should not fail). Use a thread so we see the threading.Local
|
||||
# in a pristine state.
|
||||
with ThreadPoolExecutor(1) as e:
|
||||
yield e.submit(IOLoop.clear_current)
|
||||
|
||||
|
||||
class TestIOLoopAddCallback(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(TestIOLoopAddCallback, self).setUp()
|
||||
|
|
@ -485,11 +547,12 @@ class TestIOLoopAddCallback(AsyncTestCase):
|
|||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop()
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
with ignore_deprecation():
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped)
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped)
|
||||
|
||||
self.wait()
|
||||
|
||||
|
|
@ -503,11 +566,12 @@ class TestIOLoopAddCallback(AsyncTestCase):
|
|||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop((foo, bar))
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
with ignore_deprecation():
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped, 1, bar=2)
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped, 1, bar=2)
|
||||
|
||||
result = self.wait()
|
||||
self.assertEqual(result, (1, 2))
|
||||
|
|
@ -552,15 +616,86 @@ class TestIOLoopFutures(AsyncTestCase):
|
|||
|
||||
# stack_context propagates to the ioloop callback, but the worker
|
||||
# task just has its exceptions caught and saved in the Future.
|
||||
with futures.ThreadPoolExecutor(1) as pool:
|
||||
with ExceptionStackContext(handle_exception):
|
||||
self.io_loop.add_future(pool.submit(task), callback)
|
||||
ready.set()
|
||||
self.wait()
|
||||
with ignore_deprecation():
|
||||
with futures.ThreadPoolExecutor(1) as pool:
|
||||
with ExceptionStackContext(handle_exception):
|
||||
self.io_loop.add_future(pool.submit(task), callback)
|
||||
ready.set()
|
||||
self.wait()
|
||||
|
||||
self.assertEqual(self.exception.args[0], "callback")
|
||||
self.assertEqual(self.future.exception().args[0], "worker")
|
||||
|
||||
@gen_test
|
||||
def test_run_in_executor_gen(self):
|
||||
event1 = threading.Event()
|
||||
event2 = threading.Event()
|
||||
|
||||
def sync_func(self_event, other_event):
|
||||
self_event.set()
|
||||
other_event.wait()
|
||||
# Note that return value doesn't actually do anything,
|
||||
# it is just passed through to our final assertion to
|
||||
# make sure it is passed through properly.
|
||||
return self_event
|
||||
|
||||
# Run two synchronous functions, which would deadlock if not
|
||||
# run in parallel.
|
||||
res = yield [
|
||||
IOLoop.current().run_in_executor(None, sync_func, event1, event2),
|
||||
IOLoop.current().run_in_executor(None, sync_func, event2, event1)
|
||||
]
|
||||
|
||||
self.assertEqual([event1, event2], res)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_run_in_executor_native(self):
|
||||
event1 = threading.Event()
|
||||
event2 = threading.Event()
|
||||
|
||||
def sync_func(self_event, other_event):
|
||||
self_event.set()
|
||||
other_event.wait()
|
||||
return self_event
|
||||
|
||||
# Go through an async wrapper to ensure that the result of
|
||||
# run_in_executor works with await and not just gen.coroutine
|
||||
# (simply passing the underlying concurrrent future would do that).
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def async_wrapper(self_event, other_event):
|
||||
return await IOLoop.current().run_in_executor(
|
||||
None, sync_func, self_event, other_event)
|
||||
""")
|
||||
|
||||
res = yield [
|
||||
namespace["async_wrapper"](event1, event2),
|
||||
namespace["async_wrapper"](event2, event1)
|
||||
]
|
||||
|
||||
self.assertEqual([event1, event2], res)
|
||||
|
||||
@gen_test
|
||||
def test_set_default_executor(self):
|
||||
count = [0]
|
||||
|
||||
class MyExecutor(futures.ThreadPoolExecutor):
|
||||
def submit(self, func, *args):
|
||||
count[0] += 1
|
||||
return super(MyExecutor, self).submit(func, *args)
|
||||
|
||||
event = threading.Event()
|
||||
|
||||
def sync_func():
|
||||
event.set()
|
||||
|
||||
executor = MyExecutor(1)
|
||||
loop = IOLoop.current()
|
||||
loop.set_default_executor(executor)
|
||||
yield loop.run_in_executor(None, sync_func)
|
||||
self.assertEqual(1, count[0])
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
|
||||
class TestIOLoopRunSync(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -580,14 +715,14 @@ class TestIOLoopRunSync(unittest.TestCase):
|
|||
def test_async_result(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
self.assertEqual(self.io_loop.run_sync(f), 42)
|
||||
|
||||
def test_async_exception(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.io_loop.run_sync(f)
|
||||
|
|
@ -600,18 +735,24 @@ class TestIOLoopRunSync(unittest.TestCase):
|
|||
def test_timeout(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
|
||||
yield gen.sleep(1)
|
||||
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
|
||||
|
||||
@skipBefore35
|
||||
def test_native_coroutine(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
yield gen.moment
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
await gen.Task(self.io_loop.add_callback)
|
||||
async def f2():
|
||||
await f1()
|
||||
""")
|
||||
self.io_loop.run_sync(namespace['f'])
|
||||
self.io_loop.run_sync(namespace['f2'])
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
'IOLoop configuration not available')
|
||||
class TestPeriodicCallback(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.io_loop = FakeTimeIOLoop()
|
||||
|
|
@ -653,6 +794,149 @@ class TestPeriodicCallback(unittest.TestCase):
|
|||
self.io_loop.start()
|
||||
self.assertEqual(calls, expected)
|
||||
|
||||
def test_io_loop_set_at_start(self):
|
||||
# Check PeriodicCallback uses the current IOLoop at start() time,
|
||||
# not at instantiation time.
|
||||
calls = []
|
||||
io_loop = FakeTimeIOLoop()
|
||||
|
||||
def cb():
|
||||
calls.append(io_loop.time())
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
io_loop.make_current()
|
||||
pc.start()
|
||||
io_loop.call_later(50, io_loop.stop)
|
||||
io_loop.start()
|
||||
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
|
||||
io_loop.close()
|
||||
|
||||
|
||||
class TestPeriodicCallbackMath(unittest.TestCase):
|
||||
def simulate_calls(self, pc, durations):
|
||||
"""Simulate a series of calls to the PeriodicCallback.
|
||||
|
||||
Pass a list of call durations in seconds (negative values
|
||||
work to simulate clock adjustments during the call, or more or
|
||||
less equivalently, between calls). This method returns the
|
||||
times at which each call would be made.
|
||||
"""
|
||||
calls = []
|
||||
now = 1000
|
||||
pc._next_timeout = now
|
||||
for d in durations:
|
||||
pc._update_next(now)
|
||||
calls.append(pc._next_timeout)
|
||||
now = pc._next_timeout + d
|
||||
return calls
|
||||
|
||||
def test_basic(self):
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
self.assertEqual(self.simulate_calls(pc, [0] * 5),
|
||||
[1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
def test_overrun(self):
|
||||
# If a call runs for too long, we skip entire cycles to get
|
||||
# back on schedule.
|
||||
call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
|
||||
expected = [
|
||||
1010, 1020, 1030, # first 3 calls on schedule
|
||||
1050, 1070, # next 2 delayed one cycle
|
||||
1100, 1130, # next 2 delayed 2 cycles
|
||||
1170, 1210, # next 2 delayed 3 cycles
|
||||
1220, 1230, # then back on schedule.
|
||||
]
|
||||
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
self.assertEqual(self.simulate_calls(pc, call_durations),
|
||||
expected)
|
||||
|
||||
def test_clock_backwards(self):
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
# Backwards jumps are ignored, potentially resulting in a
|
||||
# slightly slow schedule (although we assume that when
|
||||
# time.time() and time.monotonic() are different, time.time()
|
||||
# is getting adjusted by NTP and is therefore more accurate)
|
||||
self.assertEqual(self.simulate_calls(pc, [-2, -1, -3, -2, 0]),
|
||||
[1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
# For big jumps, we should perhaps alter the schedule, but we
|
||||
# don't currently. This trace shows that we run callbacks
|
||||
# every 10s of time.time(), but the first and second calls are
|
||||
# 110s of real time apart because the backwards jump is
|
||||
# ignored.
|
||||
self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]),
|
||||
[1010, 1020, 1030])
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_jitter(self):
|
||||
random_times = [0.5, 1, 0, 0.75]
|
||||
expected = [1010, 1022.5, 1030, 1041.25]
|
||||
call_durations = [0] * len(random_times)
|
||||
pc = PeriodicCallback(None, 10000, jitter=0.5)
|
||||
|
||||
def mock_random():
|
||||
return random_times.pop(0)
|
||||
with mock.patch('random.random', mock_random):
|
||||
self.assertEqual(self.simulate_calls(pc, call_durations),
|
||||
expected)
|
||||
|
||||
|
||||
class TestIOLoopConfiguration(unittest.TestCase):
|
||||
def run_python(self, *statements):
|
||||
statements = [
|
||||
'from tornado.ioloop import IOLoop, PollIOLoop',
|
||||
'classname = lambda x: x.__class__.__name__',
|
||||
] + list(statements)
|
||||
args = [sys.executable, '-c', '; '.join(statements)]
|
||||
return native_str(subprocess.check_output(args)).strip()
|
||||
|
||||
def test_default(self):
|
||||
if asyncio is not None:
|
||||
# When asyncio is available, it is used by default.
|
||||
cls = self.run_python('print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
cls = self.run_python('print(classname(IOLoop()))')
|
||||
self.assertEqual(cls, 'AsyncIOLoop')
|
||||
else:
|
||||
# Otherwise, the default is a subclass of PollIOLoop
|
||||
is_poll = self.run_python(
|
||||
'print(isinstance(IOLoop.current(), PollIOLoop))')
|
||||
self.assertEqual(is_poll, 'True')
|
||||
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
"IOLoop configuration not available")
|
||||
def test_explicit_select(self):
|
||||
# SelectIOLoop can always be configured explicitly.
|
||||
default_class = self.run_python(
|
||||
'IOLoop.configure("tornado.platform.select.SelectIOLoop")',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(default_class, 'SelectIOLoop')
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
def test_asyncio(self):
|
||||
cls = self.run_python(
|
||||
'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
def test_asyncio_main(self):
|
||||
cls = self.run_python(
|
||||
'from tornado.platform.asyncio import AsyncIOMainLoop',
|
||||
'AsyncIOMainLoop().install()',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
|
||||
@unittest.skipIf(twisted is None, "twisted module not present")
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
"IOLoop configuration not available")
|
||||
def test_twisted(self):
|
||||
cls = self.run_python(
|
||||
'from tornado.platform.twisted import TwistedIOLoop',
|
||||
'TwistedIOLoop().install()',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'TwistedIOLoop')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
|
@ -8,7 +8,7 @@ import tempfile
|
|||
import tornado.locale
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado.test.util import unittest, skipOnAppEngine
|
||||
from tornado.util import u, unicode_type
|
||||
from tornado.util import unicode_type
|
||||
|
||||
|
||||
class TranslationLoaderTest(unittest.TestCase):
|
||||
|
|
@ -35,7 +35,7 @@ class TranslationLoaderTest(unittest.TestCase):
|
|||
os.path.join(os.path.dirname(__file__), 'csv_translations'))
|
||||
locale = tornado.locale.get("fr_FR")
|
||||
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
|
||||
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
|
||||
# tempfile.mkdtemp is not available on app engine.
|
||||
@skipOnAppEngine
|
||||
|
|
@ -55,7 +55,7 @@ class TranslationLoaderTest(unittest.TestCase):
|
|||
tornado.locale.load_translations(tmpdir)
|
||||
locale = tornado.locale.get('fr_FR')
|
||||
self.assertIsInstance(locale, tornado.locale.CSVLocale)
|
||||
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
|
@ -65,20 +65,20 @@ class TranslationLoaderTest(unittest.TestCase):
|
|||
"tornado_test")
|
||||
locale = tornado.locale.get("fr_FR")
|
||||
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
|
||||
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
|
||||
self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
|
||||
self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
|
||||
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u"le club")
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u"les clubs")
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
|
||||
|
||||
|
||||
class LocaleDataTest(unittest.TestCase):
|
||||
def test_non_ascii_name(self):
|
||||
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
|
||||
self.assertTrue(isinstance(name, unicode_type))
|
||||
self.assertEqual(name, u('Espa\u00f1ol'))
|
||||
self.assertEqual(name, u'Espa\u00f1ol')
|
||||
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
|
||||
|
||||
|
||||
|
|
@ -89,16 +89,17 @@ class EnglishTest(unittest.TestCase):
|
|||
self.assertEqual(locale.format_date(date, full_format=True),
|
||||
'April 28, 2013 at 6:35 pm')
|
||||
|
||||
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
|
||||
now = datetime.datetime.utcnow()
|
||||
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
|
||||
'2 seconds ago')
|
||||
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
|
||||
'2 minutes ago')
|
||||
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(hours=2), full_format=False),
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
|
||||
'2 hours ago')
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1), full_format=False, shorter=True),
|
||||
'yesterday')
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1),
|
||||
full_format=False, shorter=True), 'yesterday')
|
||||
|
||||
date = now - datetime.timedelta(days=2)
|
||||
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from datetime import timedelta
|
||||
|
||||
from tornado import gen, locks
|
||||
|
|
@ -35,6 +35,16 @@ class ConditionTest(AsyncTestCase):
|
|||
self.history.append(key)
|
||||
future.add_done_callback(callback)
|
||||
|
||||
def loop_briefly(self):
|
||||
"""Run all queued callbacks on the IOLoop.
|
||||
|
||||
In these tests, this method is used after calling notify() to
|
||||
preserve the pre-5.0 behavior in which callbacks ran
|
||||
synchronously.
|
||||
"""
|
||||
self.io_loop.add_callback(self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_repr(self):
|
||||
c = locks.Condition()
|
||||
self.assertIn('Condition', repr(c))
|
||||
|
|
@ -53,8 +63,10 @@ class ConditionTest(AsyncTestCase):
|
|||
self.record_done(c.wait(), 'wait1')
|
||||
self.record_done(c.wait(), 'wait2')
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.history.append('notify1')
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.history.append('notify2')
|
||||
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
|
||||
self.history)
|
||||
|
|
@ -65,12 +77,15 @@ class ConditionTest(AsyncTestCase):
|
|||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify(3)
|
||||
self.loop_briefly()
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
self.assertEqual(list(range(3)), self.history)
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.assertEqual(list(range(4)), self.history)
|
||||
c.notify(2)
|
||||
self.loop_briefly()
|
||||
self.assertEqual(list(range(6)), self.history)
|
||||
|
||||
def test_notify_all(self):
|
||||
|
|
@ -79,6 +94,7 @@ class ConditionTest(AsyncTestCase):
|
|||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify_all()
|
||||
self.loop_briefly()
|
||||
self.history.append('notify_all')
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
|
|
@ -125,6 +141,7 @@ class ConditionTest(AsyncTestCase):
|
|||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
c.notify()
|
||||
yield
|
||||
self.assertEqual(['timeout', 0, 2, 3], self.history)
|
||||
|
||||
@gen_test
|
||||
|
|
@ -139,6 +156,7 @@ class ConditionTest(AsyncTestCase):
|
|||
self.assertEqual(['timeout'], self.history)
|
||||
|
||||
c.notify_all()
|
||||
yield
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
|
||||
@gen_test
|
||||
|
|
@ -154,6 +172,7 @@ class ConditionTest(AsyncTestCase):
|
|||
# resolving third future.
|
||||
futures[1].add_done_callback(lambda _: c.notify())
|
||||
c.notify(2)
|
||||
yield
|
||||
self.assertTrue(all(f.done() for f in futures))
|
||||
|
||||
@gen_test
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
|
|
@ -13,7 +12,7 @@
|
|||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import contextlib
|
||||
import glob
|
||||
|
|
@ -29,7 +28,7 @@ from tornado.escape import utf8
|
|||
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
|
||||
from tornado.options import OptionParser
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import u, basestring_type
|
||||
from tornado.util import basestring_type
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
@ -42,7 +41,8 @@ def ignore_bytes_warning():
|
|||
class LogFormatterTest(unittest.TestCase):
|
||||
# Matches the output of a single logging call (which may be multiple lines
|
||||
# if a traceback was included, so we use the DOTALL option)
|
||||
LINE_RE = re.compile(b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
|
||||
LINE_RE = re.compile(
|
||||
b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
|
||||
|
||||
def setUp(self):
|
||||
self.formatter = LogFormatter(color=False)
|
||||
|
|
@ -51,9 +51,9 @@ class LogFormatterTest(unittest.TestCase):
|
|||
# for testing. (testing with color off fails to expose some potential
|
||||
# encoding issues from the control characters)
|
||||
self.formatter._colors = {
|
||||
logging.ERROR: u("\u0001"),
|
||||
logging.ERROR: u"\u0001",
|
||||
}
|
||||
self.formatter._normal = u("\u0002")
|
||||
self.formatter._normal = u"\u0002"
|
||||
# construct a Logger directly to bypass getLogger's caching
|
||||
self.logger = logging.Logger('LogFormatterTest')
|
||||
self.logger.propagate = False
|
||||
|
|
@ -96,16 +96,16 @@ class LogFormatterTest(unittest.TestCase):
|
|||
|
||||
def test_utf8_logging(self):
|
||||
with ignore_bytes_warning():
|
||||
self.logger.error(u("\u00e9").encode("utf8"))
|
||||
self.logger.error(u"\u00e9".encode("utf8"))
|
||||
if issubclass(bytes, basestring_type):
|
||||
# on python 2, utf8 byte strings (and by extension ascii byte
|
||||
# strings) are passed through as-is.
|
||||
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
|
||||
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
|
||||
else:
|
||||
# on python 3, byte strings always get repr'd even if
|
||||
# they're ascii-only, so this degenerates into another
|
||||
# copy of test_bytes_logging.
|
||||
self.assertEqual(self.get_output(), utf8(repr(utf8(u("\u00e9")))))
|
||||
self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))
|
||||
|
||||
def test_bytes_exception_logging(self):
|
||||
try:
|
||||
|
|
@ -128,8 +128,8 @@ class UnicodeLogFormatterTest(LogFormatterTest):
|
|||
return logging.FileHandler(filename, encoding="utf8")
|
||||
|
||||
def test_unicode_logging(self):
|
||||
self.logger.error(u("\u00e9"))
|
||||
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
|
||||
self.logger.error(u"\u00e9")
|
||||
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
|
||||
|
||||
|
||||
class EnablePrettyLoggingTest(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
|
|
@ -7,10 +8,12 @@ from subprocess import Popen
|
|||
import sys
|
||||
import time
|
||||
|
||||
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
|
||||
from tornado.netutil import (
|
||||
BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
|
||||
)
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
|
||||
from tornado.test.util import unittest, skipIfNoNetwork
|
||||
from tornado.test.util import unittest, skipIfNoNetwork, ignore_deprecation
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
|
|
@ -18,15 +21,15 @@ except ImportError:
|
|||
futures = None
|
||||
|
||||
try:
|
||||
import pycares
|
||||
import pycares # type: ignore
|
||||
except ImportError:
|
||||
pycares = None
|
||||
else:
|
||||
from tornado.platform.caresresolver import CaresResolver
|
||||
|
||||
try:
|
||||
import twisted
|
||||
import twisted.names
|
||||
import twisted # type: ignore
|
||||
import twisted.names # type: ignore
|
||||
except ImportError:
|
||||
twisted = None
|
||||
else:
|
||||
|
|
@ -35,7 +38,8 @@ else:
|
|||
|
||||
class _ResolverTestMixin(object):
|
||||
def test_localhost(self):
|
||||
self.resolver.resolve('localhost', 80, callback=self.stop)
|
||||
with ignore_deprecation():
|
||||
self.resolver.resolve('localhost', 80, callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
|
||||
|
||||
|
|
@ -55,29 +59,30 @@ class _ResolverErrorTestMixin(object):
|
|||
self.stop(exc_val)
|
||||
return True # Halt propagation.
|
||||
|
||||
with ExceptionStackContext(handler):
|
||||
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(handler):
|
||||
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
|
||||
|
||||
result = self.wait()
|
||||
self.assertIsInstance(result, Exception)
|
||||
|
||||
@gen_test
|
||||
def test_future_interface_bad_host(self):
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(IOError):
|
||||
yield self.resolver.resolve('an invalid domain', 80,
|
||||
socket.AF_UNSPEC)
|
||||
|
||||
|
||||
def _failing_getaddrinfo(*args):
|
||||
"""Dummy implementation of getaddrinfo for use in mocks"""
|
||||
raise socket.gaierror("mock: lookup failed")
|
||||
raise socket.gaierror(errno.EIO, "mock: lookup failed")
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(BlockingResolverTest, self).setUp()
|
||||
self.resolver = BlockingResolver(io_loop=self.io_loop)
|
||||
self.resolver = BlockingResolver()
|
||||
|
||||
|
||||
# getaddrinfo-based tests need mocking to reliably generate errors;
|
||||
|
|
@ -86,7 +91,7 @@ class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
|
|||
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
|
||||
def setUp(self):
|
||||
super(BlockingResolverErrorTest, self).setUp()
|
||||
self.resolver = BlockingResolver(io_loop=self.io_loop)
|
||||
self.resolver = BlockingResolver()
|
||||
self.real_getaddrinfo = socket.getaddrinfo
|
||||
socket.getaddrinfo = _failing_getaddrinfo
|
||||
|
||||
|
|
@ -95,12 +100,31 @@ class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
|
|||
super(BlockingResolverErrorTest, self).tearDown()
|
||||
|
||||
|
||||
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(OverrideResolverTest, self).setUp()
|
||||
mapping = {
|
||||
('google.com', 80): ('1.2.3.4', 80),
|
||||
('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
|
||||
('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
|
||||
}
|
||||
self.resolver = OverrideResolver(BlockingResolver(), mapping)
|
||||
|
||||
@gen_test
|
||||
def test_resolve_multiaddr(self):
|
||||
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET)
|
||||
self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
|
||||
|
||||
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET6)
|
||||
self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(ThreadedResolverTest, self).setUp()
|
||||
self.resolver = ThreadedResolver(io_loop=self.io_loop)
|
||||
self.resolver = ThreadedResolver()
|
||||
|
||||
def tearDown(self):
|
||||
self.resolver.close()
|
||||
|
|
@ -110,7 +134,7 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
|||
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
|
||||
def setUp(self):
|
||||
super(ThreadedResolverErrorTest, self).setUp()
|
||||
self.resolver = BlockingResolver(io_loop=self.io_loop)
|
||||
self.resolver = BlockingResolver()
|
||||
self.real_getaddrinfo = socket.getaddrinfo
|
||||
socket.getaddrinfo = _failing_getaddrinfo
|
||||
|
||||
|
|
@ -157,19 +181,23 @@ class ThreadedResolverImportTest(unittest.TestCase):
|
|||
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(CaresResolverTest, self).setUp()
|
||||
self.resolver = CaresResolver(io_loop=self.io_loop)
|
||||
self.resolver = CaresResolver()
|
||||
|
||||
|
||||
# TwistedResolver produces consistent errors in our test cases so we
|
||||
# can test the regular and error cases in the same class.
|
||||
# could test the regular and error cases in the same class. However,
|
||||
# in the error cases it appears that cleanup of socket objects is
|
||||
# handled asynchronously and occasionally results in "unclosed socket"
|
||||
# warnings if not given time to shut down (and there is no way to
|
||||
# explicitly shut it down). This makes the test flaky, so we do not
|
||||
# test error cases here.
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(twisted is None, "twisted module not present")
|
||||
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
|
||||
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin,
|
||||
_ResolverErrorTestMixin):
|
||||
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(TwistedResolverTest, self).setUp()
|
||||
self.resolver = TwistedResolver(io_loop=self.io_loop)
|
||||
self.resolver = TwistedResolver()
|
||||
|
||||
|
||||
class IsValidIPTest(unittest.TestCase):
|
||||
|
|
@ -203,9 +231,10 @@ class TestPortAllocation(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
|
||||
def test_reuse_port(self):
|
||||
sockets = []
|
||||
socket, port = bind_unused_port(reuse_port=True)
|
||||
try:
|
||||
sockets = bind_sockets(port, 'localhost', reuse_port=True)
|
||||
sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
|
||||
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
|
||||
finally:
|
||||
socket.close()
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@ port=443
|
|||
username='李康'
|
||||
|
||||
foo_bar='a'
|
||||
|
||||
my_path = __file__
|
||||
|
|
|
|||
|
|
@ -1,28 +1,41 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tornado.options import OptionParser, Error
|
||||
from tornado.util import basestring_type
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import basestring_type, PY3
|
||||
from tornado.test.util import unittest, subTest
|
||||
|
||||
if PY3:
|
||||
from io import StringIO
|
||||
else:
|
||||
from cStringIO import StringIO
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO # python 2
|
||||
except ImportError:
|
||||
from io import StringIO # python 3
|
||||
|
||||
try:
|
||||
from unittest import mock # python 3.3
|
||||
# py33+
|
||||
from unittest import mock # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
import mock # third-party mock package
|
||||
import mock # type: ignore
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
class Email(object):
|
||||
def __init__(self, value):
|
||||
if isinstance(value, str) and '@' in value:
|
||||
self._value = value
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
|
||||
class OptionsTest(unittest.TestCase):
|
||||
def test_parse_command_line(self):
|
||||
options = OptionParser()
|
||||
|
|
@ -34,10 +47,13 @@ class OptionsTest(unittest.TestCase):
|
|||
options = OptionParser()
|
||||
options.define("port", default=80)
|
||||
options.define("username", default='foo')
|
||||
options.parse_config_file(os.path.join(os.path.dirname(__file__),
|
||||
"options_test.cfg"))
|
||||
self.assertEquals(options.port, 443)
|
||||
options.define("my_path")
|
||||
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"options_test.cfg")
|
||||
options.parse_config_file(config_path)
|
||||
self.assertEqual(options.port, 443)
|
||||
self.assertEqual(options.username, "李康")
|
||||
self.assertEqual(options.my_path, config_path)
|
||||
|
||||
def test_parse_callbacks(self):
|
||||
options = OptionParser()
|
||||
|
|
@ -131,6 +147,12 @@ class OptionsTest(unittest.TestCase):
|
|||
options = self._sample_options()
|
||||
self.assertEqual(1, options['a'])
|
||||
|
||||
def test_setitem(self):
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1, type=int)
|
||||
options['foo'] = 2
|
||||
self.assertEqual(options['foo'], 2)
|
||||
|
||||
def test_items(self):
|
||||
options = self._sample_options()
|
||||
# OptionParsers always define 'help'.
|
||||
|
|
@ -179,7 +201,7 @@ class OptionsTest(unittest.TestCase):
|
|||
self.assertEqual(options.foo, 5)
|
||||
self.assertEqual(options.foo, 2)
|
||||
|
||||
def test_types(self):
|
||||
def _define_options(self):
|
||||
options = OptionParser()
|
||||
options.define('str', type=str)
|
||||
options.define('basestring', type=basestring_type)
|
||||
|
|
@ -187,13 +209,11 @@ class OptionsTest(unittest.TestCase):
|
|||
options.define('float', type=float)
|
||||
options.define('datetime', type=datetime.datetime)
|
||||
options.define('timedelta', type=datetime.timedelta)
|
||||
options.parse_command_line(['main.py',
|
||||
'--str=asdf',
|
||||
'--basestring=qwer',
|
||||
'--int=42',
|
||||
'--float=1.5',
|
||||
'--datetime=2013-04-28 05:16',
|
||||
'--timedelta=45s'])
|
||||
options.define('email', type=Email)
|
||||
options.define('list-of-int', type=int, multiple=True)
|
||||
return options
|
||||
|
||||
def _check_options_values(self, options):
|
||||
self.assertEqual(options.str, 'asdf')
|
||||
self.assertEqual(options.basestring, 'qwer')
|
||||
self.assertEqual(options.int, 42)
|
||||
|
|
@ -201,6 +221,30 @@ class OptionsTest(unittest.TestCase):
|
|||
self.assertEqual(options.datetime,
|
||||
datetime.datetime(2013, 4, 28, 5, 16))
|
||||
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
|
||||
self.assertEqual(options.email.value, 'tornado@web.com')
|
||||
self.assertTrue(isinstance(options.email, Email))
|
||||
self.assertEqual(options.list_of_int, [1, 2, 3])
|
||||
|
||||
def test_types(self):
|
||||
options = self._define_options()
|
||||
options.parse_command_line(['main.py',
|
||||
'--str=asdf',
|
||||
'--basestring=qwer',
|
||||
'--int=42',
|
||||
'--float=1.5',
|
||||
'--datetime=2013-04-28 05:16',
|
||||
'--timedelta=45s',
|
||||
'--email=tornado@web.com',
|
||||
'--list-of-int=1,2,3'])
|
||||
self._check_options_values(options)
|
||||
|
||||
def test_types_with_conf_file(self):
|
||||
for config_file_name in ("options_test_types.cfg",
|
||||
"options_test_types_str.cfg"):
|
||||
options = self._define_options()
|
||||
options.parse_config_file(os.path.join(os.path.dirname(__file__),
|
||||
config_file_name))
|
||||
self._check_options_values(options)
|
||||
|
||||
def test_multiple_string(self):
|
||||
options = OptionParser()
|
||||
|
|
@ -222,6 +266,24 @@ class OptionsTest(unittest.TestCase):
|
|||
self.assertRegexpMatches(str(cm.exception),
|
||||
'Option.*foo.*already defined')
|
||||
|
||||
def test_error_redefine_underscore(self):
|
||||
# Ensure that the dash/underscore normalization doesn't
|
||||
# interfere with the redefinition error.
|
||||
tests = [
|
||||
('foo-bar', 'foo-bar'),
|
||||
('foo_bar', 'foo_bar'),
|
||||
('foo-bar', 'foo_bar'),
|
||||
('foo_bar', 'foo-bar'),
|
||||
]
|
||||
for a, b in tests:
|
||||
with subTest(self, a=a, b=b):
|
||||
options = OptionParser()
|
||||
options.define(a)
|
||||
with self.assertRaises(Error) as cm:
|
||||
options.define(b)
|
||||
self.assertRegexpMatches(str(cm.exception),
|
||||
'Option.*foo.bar.*already defined')
|
||||
|
||||
def test_dash_underscore_cli(self):
|
||||
# Dashes and underscores should be interchangeable.
|
||||
for defined_name in ['foo-bar', 'foo_bar']:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from datetime import datetime, timedelta
|
||||
from tornado.test.options_test import Email
|
||||
|
||||
str = 'asdf'
|
||||
basestring = 'qwer'
|
||||
int = 42
|
||||
float = 1.5
|
||||
datetime = datetime(2013, 4, 28, 5, 16)
|
||||
timedelta = timedelta(0, 45)
|
||||
email = Email('tornado@web.com')
|
||||
list_of_int = [1, 2, 3]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
str = 'asdf'
|
||||
basestring = 'qwer'
|
||||
int = 42
|
||||
float = 1.5
|
||||
datetime = '2013-04-28 05:16'
|
||||
timedelta = '45s'
|
||||
email = 'tornado@web.com'
|
||||
list_of_int = '1,2,3'
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from tornado.httpclient import HTTPClient, HTTPError
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
|
|
@ -17,12 +16,15 @@ from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
|
|||
from tornado.test.util import unittest, skipIfNonUnix
|
||||
from tornado.web import RequestHandler, Application
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
def skip_if_twisted():
|
||||
if IOLoop.configured_class().__name__.endswith(('TwistedIOLoop',
|
||||
'AsyncIOMainLoop')):
|
||||
raise unittest.SkipTest("Process tests not compatible with "
|
||||
"TwistedIOLoop or AsyncIOMainLoop")
|
||||
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
|
||||
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
|
||||
|
||||
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
|
||||
|
||||
|
|
@ -58,11 +60,12 @@ class ProcessTest(unittest.TestCase):
|
|||
super(ProcessTest, self).tearDown()
|
||||
|
||||
def test_multi_process(self):
|
||||
# This test can't work on twisted because we use the global reactor
|
||||
# and have no way to get it back into a sane state after the fork.
|
||||
# This test doesn't work on twisted because we use the global
|
||||
# reactor and don't restore it to a sane state after the fork
|
||||
# (asyncio has the same issue, but we have a special case in
|
||||
# place for it).
|
||||
skip_if_twisted()
|
||||
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
|
||||
self.assertFalse(IOLoop.initialized())
|
||||
sock, port = bind_unused_port()
|
||||
|
||||
def get_url(path):
|
||||
|
|
@ -81,6 +84,10 @@ class ProcessTest(unittest.TestCase):
|
|||
sock.close()
|
||||
return
|
||||
try:
|
||||
if asyncio is not None:
|
||||
# Reset the global asyncio event loop, which was put into
|
||||
# a broken state by the fork.
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
if id in (0, 1):
|
||||
self.assertEqual(id, task_id())
|
||||
server = HTTPServer(self.get_app())
|
||||
|
|
@ -136,6 +143,7 @@ class ProcessTest(unittest.TestCase):
|
|||
|
||||
@skipIfNonUnix
|
||||
class SubprocessTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_subprocess(self):
|
||||
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
|
||||
# This test fails non-deterministically with LayeredTwistedIOLoop.
|
||||
|
|
@ -147,54 +155,52 @@ class SubprocessTest(AsyncTestCase):
|
|||
"LayeredTwistedIOLoop")
|
||||
subproc = Subprocess([sys.executable, '-u', '-i'],
|
||||
stdin=Subprocess.STREAM,
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
|
||||
io_loop=self.io_loop)
|
||||
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
|
||||
subproc.stdout.read_until(b'>>> ', self.stop)
|
||||
self.wait()
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
self.addCleanup(subproc.stdout.close)
|
||||
self.addCleanup(subproc.stdin.close)
|
||||
yield subproc.stdout.read_until(b'>>> ')
|
||||
subproc.stdin.write(b"print('hello')\n")
|
||||
subproc.stdout.read_until(b'\n', self.stop)
|
||||
data = self.wait()
|
||||
data = yield subproc.stdout.read_until(b'\n')
|
||||
self.assertEqual(data, b"hello\n")
|
||||
|
||||
subproc.stdout.read_until(b">>> ", self.stop)
|
||||
self.wait()
|
||||
yield subproc.stdout.read_until(b">>> ")
|
||||
subproc.stdin.write(b"raise SystemExit\n")
|
||||
subproc.stdout.read_until_close(self.stop)
|
||||
data = self.wait()
|
||||
data = yield subproc.stdout.read_until_close()
|
||||
self.assertEqual(data, b"")
|
||||
|
||||
@gen_test
|
||||
def test_close_stdin(self):
|
||||
# Close the parent's stdin handle and see that the child recognizes it.
|
||||
subproc = Subprocess([sys.executable, '-u', '-i'],
|
||||
stdin=Subprocess.STREAM,
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
|
||||
io_loop=self.io_loop)
|
||||
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
|
||||
subproc.stdout.read_until(b'>>> ', self.stop)
|
||||
self.wait()
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
yield subproc.stdout.read_until(b'>>> ')
|
||||
subproc.stdin.close()
|
||||
subproc.stdout.read_until_close(self.stop)
|
||||
data = self.wait()
|
||||
data = yield subproc.stdout.read_until_close()
|
||||
self.assertEqual(data, b"\n")
|
||||
|
||||
@gen_test
|
||||
def test_stderr(self):
|
||||
# This test is mysteriously flaky on twisted: it succeeds, but logs
|
||||
# an error of EBADF on closing a file descriptor.
|
||||
skip_if_twisted()
|
||||
subproc = Subprocess([sys.executable, '-u', '-c',
|
||||
r"import sys; sys.stderr.write('hello\n')"],
|
||||
stderr=Subprocess.STREAM,
|
||||
io_loop=self.io_loop)
|
||||
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
|
||||
subproc.stderr.read_until(b'\n', self.stop)
|
||||
data = self.wait()
|
||||
stderr=Subprocess.STREAM)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
data = yield subproc.stderr.read_until(b'\n')
|
||||
self.assertEqual(data, b'hello\n')
|
||||
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
|
||||
subproc.stderr.close()
|
||||
|
||||
def test_sigchild(self):
|
||||
# Twisted's SIGCHLD handler and Subprocess's conflict with each other.
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize(io_loop=self.io_loop)
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'pass'],
|
||||
io_loop=self.io_loop)
|
||||
subproc = Subprocess([sys.executable, '-c', 'pass'])
|
||||
subproc.set_exit_callback(self.stop)
|
||||
ret = self.wait()
|
||||
self.assertEqual(ret, 0)
|
||||
|
|
@ -212,14 +218,31 @@ class SubprocessTest(AsyncTestCase):
|
|||
|
||||
def test_sigchild_signal(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize(io_loop=self.io_loop)
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c',
|
||||
'import time; time.sleep(30)'],
|
||||
io_loop=self.io_loop)
|
||||
stdout=Subprocess.STREAM)
|
||||
self.addCleanup(subproc.stdout.close)
|
||||
subproc.set_exit_callback(self.stop)
|
||||
os.kill(subproc.pid, signal.SIGTERM)
|
||||
ret = self.wait()
|
||||
try:
|
||||
ret = self.wait(timeout=1.0)
|
||||
except AssertionError:
|
||||
# We failed to get the termination signal. This test is
|
||||
# occasionally flaky on pypy, so try to get a little more
|
||||
# information: did the process close its stdout
|
||||
# (indicating that the problem is in the parent process's
|
||||
# signal handling) or did the child process somehow fail
|
||||
# to terminate?
|
||||
subproc.stdout.read_until_close(callback=self.stop)
|
||||
try:
|
||||
self.wait(timeout=1.0)
|
||||
except AssertionError:
|
||||
raise AssertionError("subprocess failed to terminate")
|
||||
else:
|
||||
raise AssertionError("subprocess closed stdout but failed to "
|
||||
"get termination signal")
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
self.assertEqual(ret, -signal.SIGTERM)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from datetime import timedelta
|
||||
from random import random
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import ThreadedResolver
|
||||
from tornado.util import u
|
||||
|
||||
# When this module is imported, it runs getaddrinfo on a thread. Since
|
||||
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
|
||||
|
|
@ -9,4 +8,4 @@ from tornado.util import u
|
|||
# this deadlock.
|
||||
|
||||
resolver = ThreadedResolver()
|
||||
IOLoop.current().run_sync(lambda: resolver.resolve(u('localhost'), 80))
|
||||
IOLoop.current().run_sync(lambda: resolver.resolve(u'localhost', 80))
|
||||
|
|
|
|||
247
Shared/lib/python3.4/site-packages/tornado/test/routing_test.py
Normal file
247
Shared/lib/python3.4/site-packages/tornado/test/routing_test.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
|
||||
from tornado.routing import HostMatches, PathMatches, ReversibleRouter, Router, Rule, RuleRouter
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.web import Application, HTTPError, RequestHandler
|
||||
from tornado.wsgi import WSGIContainer
|
||||
|
||||
|
||||
class BasicRouter(Router):
|
||||
def find_handler(self, request, **kwargs):
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": "2"}),
|
||||
b"OK"
|
||||
)
|
||||
self.connection.finish()
|
||||
|
||||
return MessageDelegate(request.connection)
|
||||
|
||||
|
||||
class BasicRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return BasicRouter()
|
||||
|
||||
def test_basic_router(self):
|
||||
response = self.fetch("/any_request")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
|
||||
resources = {}
|
||||
|
||||
|
||||
class GetResource(RequestHandler):
|
||||
def get(self, path):
|
||||
if path not in resources:
|
||||
raise HTTPError(404)
|
||||
|
||||
self.finish(resources[path])
|
||||
|
||||
|
||||
class PostResource(RequestHandler):
|
||||
def post(self, path):
|
||||
resources[path] = self.request.body
|
||||
|
||||
|
||||
class HTTPMethodRouter(Router):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
handler = GetResource if request.method == "GET" else PostResource
|
||||
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
|
||||
|
||||
|
||||
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return HTTPMethodRouter(Application())
|
||||
|
||||
def test_http_method_router(self):
|
||||
response = self.fetch("/post_resource", method="POST", body="data")
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
response = self.fetch("/get_resource")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch("/post_resource")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, b"data")
|
||||
|
||||
|
||||
def _get_named_handler(handler_name):
|
||||
class Handler(RequestHandler):
|
||||
def get(self, *args, **kwargs):
|
||||
if self.application.settings.get("app_name") is not None:
|
||||
self.write(self.application.settings["app_name"] + ": ")
|
||||
|
||||
self.finish(handler_name + ": " + self.reverse_url(handler_name))
|
||||
|
||||
return Handler
|
||||
|
||||
|
||||
FirstHandler = _get_named_handler("first_handler")
|
||||
SecondHandler = _get_named_handler("second_handler")
|
||||
|
||||
|
||||
class CustomRouter(ReversibleRouter):
|
||||
def __init__(self):
|
||||
super(CustomRouter, self).__init__()
|
||||
self.routes = {}
|
||||
|
||||
def add_routes(self, routes):
|
||||
self.routes.update(routes)
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
if request.path in self.routes:
|
||||
app, handler = self.routes[request.path]
|
||||
return app.get_handler_delegate(request, handler)
|
||||
|
||||
def reverse_url(self, name, *args):
|
||||
handler_path = '/' + name
|
||||
return handler_path if handler_path in self.routes else None
|
||||
|
||||
|
||||
class CustomRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class CustomApplication(Application):
|
||||
def reverse_url(self, name, *args):
|
||||
return router.reverse_url(name, *args)
|
||||
|
||||
router = CustomRouter()
|
||||
app1 = CustomApplication(app_name="app1")
|
||||
app2 = CustomApplication(app_name="app2")
|
||||
|
||||
router.add_routes({
|
||||
"/first_handler": (app1, FirstHandler),
|
||||
"/second_handler": (app2, SecondHandler),
|
||||
"/first_handler_second_app": (app2, FirstHandler),
|
||||
})
|
||||
|
||||
return router
|
||||
|
||||
def test_custom_router(self):
|
||||
response = self.fetch("/first_handler")
|
||||
self.assertEqual(response.body, b"app1: first_handler: /first_handler")
|
||||
response = self.fetch("/second_handler")
|
||||
self.assertEqual(response.body, b"app2: second_handler: /second_handler")
|
||||
response = self.fetch("/first_handler_second_app")
|
||||
self.assertEqual(response.body, b"app2: first_handler: /first_handler")
|
||||
|
||||
|
||||
class ConnectionDelegate(HTTPServerConnectionDelegate):
|
||||
def start_request(self, server_conn, request_conn):
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
response_body = b"OK"
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": str(len(response_body))}))
|
||||
self.connection.write(response_body)
|
||||
self.connection.finish()
|
||||
|
||||
return MessageDelegate(request_conn)
|
||||
|
||||
|
||||
class RuleRouterTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
app = Application()
|
||||
|
||||
def request_callable(request):
|
||||
request.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": "2"}))
|
||||
request.connection.write(b"OK")
|
||||
request.connection.finish()
|
||||
|
||||
router = CustomRouter()
|
||||
router.add_routes({
|
||||
"/nested_handler": (app, _get_named_handler("nested_handler"))
|
||||
})
|
||||
|
||||
app.add_handlers(".*", [
|
||||
(HostMatches("www.example.com"), [
|
||||
(PathMatches("/first_handler"),
|
||||
"tornado.test.routing_test.SecondHandler", {}, "second_handler")
|
||||
]),
|
||||
Rule(PathMatches("/.*handler"), router),
|
||||
Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
|
||||
Rule(PathMatches("/request_callable"), request_callable),
|
||||
("/connection_delegate", ConnectionDelegate())
|
||||
])
|
||||
|
||||
return app
|
||||
|
||||
def test_rule_based_router(self):
|
||||
response = self.fetch("/first_handler")
|
||||
self.assertEqual(response.body, b"first_handler: /first_handler")
|
||||
|
||||
response = self.fetch("/first_handler", headers={'Host': 'www.example.com'})
|
||||
self.assertEqual(response.body, b"second_handler: /first_handler")
|
||||
|
||||
response = self.fetch("/nested_handler")
|
||||
self.assertEqual(response.body, b"nested_handler: /nested_handler")
|
||||
|
||||
response = self.fetch("/nested_not_found_handler")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch("/connection_delegate")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
response = self.fetch("/request_callable")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
response = self.fetch("/404")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
|
||||
class WSGIContainerTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
wsgi_app = WSGIContainer(self.wsgi_app)
|
||||
|
||||
class Handler(RequestHandler):
|
||||
def get(self, *args, **kwargs):
|
||||
self.finish(self.reverse_url("tornado"))
|
||||
|
||||
return RuleRouter([
|
||||
(PathMatches("/tornado.*"), Application([(r"/tornado/test", Handler, {}, "tornado")])),
|
||||
(PathMatches("/wsgi"), wsgi_app),
|
||||
])
|
||||
|
||||
def wsgi_app(self, environ, start_response):
|
||||
start_response("200 OK", [])
|
||||
return [b"WSGI"]
|
||||
|
||||
def test_wsgi_container(self):
|
||||
response = self.fetch("/tornado/test")
|
||||
self.assertEqual(response.body, b"/tornado/test")
|
||||
|
||||
response = self.fetch("/wsgi")
|
||||
self.assertEqual(response.body, b"WSGI")
|
||||
|
||||
def test_delegate_not_found(self):
|
||||
response = self.fetch("/404")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
import gc
|
||||
import io
|
||||
import locale # system locale module, not tornado.locale
|
||||
import logging
|
||||
import operator
|
||||
import textwrap
|
||||
import sys
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
|
|
@ -25,10 +26,12 @@ TEST_MODULES = [
|
|||
'tornado.util.doctests',
|
||||
'tornado.test.asyncio_test',
|
||||
'tornado.test.auth_test',
|
||||
'tornado.test.autoreload_test',
|
||||
'tornado.test.concurrent_test',
|
||||
'tornado.test.curl_httpclient_test',
|
||||
'tornado.test.escape_test',
|
||||
'tornado.test.gen_test',
|
||||
'tornado.test.http1connection_test',
|
||||
'tornado.test.httpclient_test',
|
||||
'tornado.test.httpserver_test',
|
||||
'tornado.test.httputil_test',
|
||||
|
|
@ -42,6 +45,7 @@ TEST_MODULES = [
|
|||
'tornado.test.options_test',
|
||||
'tornado.test.process_test',
|
||||
'tornado.test.queues_test',
|
||||
'tornado.test.routing_test',
|
||||
'tornado.test.simple_httpclient_test',
|
||||
'tornado.test.stack_context_test',
|
||||
'tornado.test.tcpclient_test',
|
||||
|
|
@ -52,6 +56,7 @@ TEST_MODULES = [
|
|||
'tornado.test.util_test',
|
||||
'tornado.test.web_test',
|
||||
'tornado.test.websocket_test',
|
||||
'tornado.test.windows_test',
|
||||
'tornado.test.wsgi_test',
|
||||
]
|
||||
|
||||
|
|
@ -60,16 +65,21 @@ def all():
|
|||
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
|
||||
|
||||
|
||||
class TornadoTextTestRunner(unittest.TextTestRunner):
|
||||
def run(self, test):
|
||||
result = super(TornadoTextTestRunner, self).run(test)
|
||||
if result.skipped:
|
||||
skip_reasons = set(reason for (test, reason) in result.skipped)
|
||||
self.stream.write(textwrap.fill(
|
||||
"Some tests were skipped because: %s" %
|
||||
", ".join(sorted(skip_reasons))))
|
||||
self.stream.write("\n")
|
||||
return result
|
||||
def test_runner_factory(stderr):
|
||||
class TornadoTextTestRunner(unittest.TextTestRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TornadoTextTestRunner, self).__init__(*args, stream=stderr, **kwargs)
|
||||
|
||||
def run(self, test):
|
||||
result = super(TornadoTextTestRunner, self).run(test)
|
||||
if result.skipped:
|
||||
skip_reasons = set(reason for (test, reason) in result.skipped)
|
||||
self.stream.write(textwrap.fill(
|
||||
"Some tests were skipped because: %s" %
|
||||
", ".join(sorted(skip_reasons))))
|
||||
self.stream.write("\n")
|
||||
return result
|
||||
return TornadoTextTestRunner
|
||||
|
||||
|
||||
class LogCounter(logging.Filter):
|
||||
|
|
@ -77,16 +87,31 @@ class LogCounter(logging.Filter):
|
|||
def __init__(self, *args, **kwargs):
|
||||
# Can't use super() because logging.Filter is an old-style class in py26
|
||||
logging.Filter.__init__(self, *args, **kwargs)
|
||||
self.warning_count = self.error_count = 0
|
||||
self.info_count = self.warning_count = self.error_count = 0
|
||||
|
||||
def filter(self, record):
|
||||
if record.levelno >= logging.ERROR:
|
||||
self.error_count += 1
|
||||
elif record.levelno >= logging.WARNING:
|
||||
self.warning_count += 1
|
||||
elif record.levelno >= logging.INFO:
|
||||
self.info_count += 1
|
||||
return True
|
||||
|
||||
|
||||
class CountingStderr(io.IOBase):
|
||||
def __init__(self, real):
|
||||
self.real = real
|
||||
self.byte_count = 0
|
||||
|
||||
def write(self, data):
|
||||
self.byte_count += len(data)
|
||||
return self.real.write(data)
|
||||
|
||||
def flush(self):
|
||||
return self.real.flush()
|
||||
|
||||
|
||||
def main():
|
||||
# The -W command-line option does not work in a virtualenv with
|
||||
# python 3 (as of virtualenv 1.7), so configure warnings
|
||||
|
|
@ -112,13 +137,18 @@ def main():
|
|||
# 2.7 and 3.2
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
# unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
|
||||
# instead of DeprecationWarnings.
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
|
||||
warnings.filterwarnings("ignore", category=BytesWarning,
|
||||
module=r"twisted\..*")
|
||||
if (3,) < sys.version_info < (3, 6):
|
||||
# Prior to 3.6, async ResourceWarnings were rather noisy
|
||||
# and even
|
||||
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
|
||||
# would generate a warning.
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning, # noqa: F821
|
||||
module=r"asyncio\..*")
|
||||
|
||||
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
|
||||
|
||||
|
|
@ -154,6 +184,12 @@ def main():
|
|||
add_parse_callback(
|
||||
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
|
||||
|
||||
# Certain errors (especially "unclosed resource" errors raised in
|
||||
# destructors) go directly to stderr instead of logging. Count
|
||||
# anything written by anything but the test runner as an error.
|
||||
orig_stderr = sys.stderr
|
||||
sys.stderr = CountingStderr(orig_stderr)
|
||||
|
||||
import tornado.testing
|
||||
kwargs = {}
|
||||
if sys.version_info >= (3, 2):
|
||||
|
|
@ -163,17 +199,21 @@ def main():
|
|||
# suppresses this behavior, although this looks like an implementation
|
||||
# detail. http://bugs.python.org/issue15626
|
||||
kwargs['warnings'] = False
|
||||
kwargs['testRunner'] = TornadoTextTestRunner
|
||||
kwargs['testRunner'] = test_runner_factory(orig_stderr)
|
||||
try:
|
||||
tornado.testing.main(**kwargs)
|
||||
finally:
|
||||
# The tests should run clean; consider it a failure if they logged
|
||||
# any warnings or errors. We'd like to ban info logs too, but
|
||||
# we can't count them cleanly due to interactions with LogTrapTestCase.
|
||||
if log_counter.warning_count > 0 or log_counter.error_count > 0:
|
||||
logging.error("logged %d warnings and %d errors",
|
||||
log_counter.warning_count, log_counter.error_count)
|
||||
# The tests should run clean; consider it a failure if they
|
||||
# logged anything at info level or above.
|
||||
if (log_counter.info_count > 0 or
|
||||
log_counter.warning_count > 0 or
|
||||
log_counter.error_count > 0 or
|
||||
sys.stderr.byte_count > 0):
|
||||
logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
|
||||
log_counter.info_count, log_counter.warning_count,
|
||||
log_counter.error_count, sys.stderr.byte_count)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
from contextlib import closing
|
||||
|
|
@ -11,25 +11,28 @@ import socket
|
|||
import ssl
|
||||
import sys
|
||||
|
||||
from tornado.escape import to_unicode
|
||||
from tornado.escape import to_unicode, utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httputil import HTTPHeaders, ResponseStartLine
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import UnsatisfiableReadError
|
||||
from tornado.locks import Event
|
||||
from tornado.log import gen_log
|
||||
from tornado.concurrent import Future
|
||||
from tornado.netutil import Resolver, bind_sockets
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPStreamClosedError, HTTPTimeoutError
|
||||
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler # noqa: E501
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest, skipBefore35, exec_test
|
||||
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
|
||||
from tornado.testing import (AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase,
|
||||
ExpectLog, gen_test)
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, skipBefore35, exec_test
|
||||
from tornado.web import RequestHandler, Application, url, stream_request_body
|
||||
|
||||
|
||||
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
def get_http_client(self):
|
||||
client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
|
||||
force_instance=True)
|
||||
client = SimpleAsyncHTTPClient(force_instance=True)
|
||||
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
|
||||
return client
|
||||
|
||||
|
|
@ -39,24 +42,33 @@ class TriggerHandler(RequestHandler):
|
|||
self.queue = queue
|
||||
self.wake_callback = wake_callback
|
||||
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
logging.debug("queuing trigger")
|
||||
self.queue.append(self.finish)
|
||||
if self.get_argument("wake", "true") == "true":
|
||||
self.wake_callback()
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
|
||||
class HangHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
pass
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
|
||||
class ContentLengthHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header("Content-Length", self.get_argument("value"))
|
||||
self.write("ok")
|
||||
self.stream = self.detach()
|
||||
IOLoop.current().spawn_callback(self.write_response)
|
||||
|
||||
@gen.coroutine
|
||||
def write_response(self):
|
||||
yield self.stream.write(utf8("HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" %
|
||||
self.get_argument("value")))
|
||||
self.stream.close()
|
||||
|
||||
|
||||
class HeadHandler(RequestHandler):
|
||||
|
|
@ -72,10 +84,8 @@ class OptionsHandler(RequestHandler):
|
|||
|
||||
class NoContentHandler(RequestHandler):
|
||||
def get(self):
|
||||
if self.get_argument("error", None):
|
||||
self.set_header("Content-Length", "5")
|
||||
self.write("hello")
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
|
||||
class SeeOtherPostHandler(RequestHandler):
|
||||
|
|
@ -99,13 +109,12 @@ class HostEchoHandler(RequestHandler):
|
|||
|
||||
|
||||
class NoContentLengthHandler(RequestHandler):
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.request.version.startswith('HTTP/1'):
|
||||
# Emulate the old HTTP/1.0 behavior of returning a body with no
|
||||
# content-length. Tornado handles content-length at the framework
|
||||
# level so we have to go around it.
|
||||
stream = self.request.connection.detach()
|
||||
stream = self.detach()
|
||||
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
|
||||
b"hello")
|
||||
stream.close()
|
||||
|
|
@ -151,16 +160,16 @@ class SimpleHTTPClientTestMixin(object):
|
|||
|
||||
def test_singleton(self):
|
||||
# Class "constructor" reuses objects on the same IOLoop
|
||||
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
|
||||
SimpleAsyncHTTPClient(self.io_loop))
|
||||
self.assertTrue(SimpleAsyncHTTPClient() is
|
||||
SimpleAsyncHTTPClient())
|
||||
# unless force_instance is used
|
||||
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
|
||||
SimpleAsyncHTTPClient(self.io_loop,
|
||||
force_instance=True))
|
||||
self.assertTrue(SimpleAsyncHTTPClient() is not
|
||||
SimpleAsyncHTTPClient(force_instance=True))
|
||||
# different IOLoops use different objects
|
||||
with closing(IOLoop()) as io_loop2:
|
||||
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
|
||||
SimpleAsyncHTTPClient(io_loop2))
|
||||
client1 = self.io_loop.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
|
||||
client2 = io_loop2.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
|
||||
self.assertTrue(client1 is not client2)
|
||||
|
||||
def test_connection_limit(self):
|
||||
with closing(self.create_client(max_clients=2)) as client:
|
||||
|
|
@ -169,8 +178,8 @@ class SimpleHTTPClientTestMixin(object):
|
|||
# Send 4 requests. Two can be sent immediately, while the others
|
||||
# will be queued
|
||||
for i in range(4):
|
||||
client.fetch(self.get_url("/trigger"),
|
||||
lambda response, i=i: (seen.append(i), self.stop()))
|
||||
client.fetch(self.get_url("/trigger")).add_done_callback(
|
||||
lambda fut, i=i: (seen.append(i), self.stop()))
|
||||
self.wait(condition=lambda: len(self.triggers) == 2)
|
||||
self.assertEqual(len(client.queue), 2)
|
||||
|
||||
|
|
@ -189,12 +198,12 @@ class SimpleHTTPClientTestMixin(object):
|
|||
self.assertEqual(set(seen), set([0, 1, 2, 3]))
|
||||
self.assertEqual(len(self.triggers), 0)
|
||||
|
||||
@gen_test
|
||||
def test_redirect_connection_limit(self):
|
||||
# following redirects should not consume additional connections
|
||||
with closing(self.create_client(max_clients=1)) as client:
|
||||
client.fetch(self.get_url('/countdown/3'), self.stop,
|
||||
max_redirects=3)
|
||||
response = self.wait()
|
||||
response = yield client.fetch(self.get_url('/countdown/3'),
|
||||
max_redirects=3)
|
||||
response.rethrow()
|
||||
|
||||
def test_gzip(self):
|
||||
|
|
@ -237,55 +246,58 @@ class SimpleHTTPClientTestMixin(object):
|
|||
# request is the original request, is a POST still
|
||||
self.assertEqual("POST", response.request.method)
|
||||
|
||||
@skipOnTravis
|
||||
@gen_test
|
||||
def test_connect_timeout(self):
|
||||
timeout = 0.1
|
||||
|
||||
class TimeoutResolver(Resolver):
|
||||
def resolve(self, *args, **kwargs):
|
||||
return Future() # never completes
|
||||
|
||||
with closing(self.create_client(resolver=TimeoutResolver())) as client:
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
yield client.fetch(self.get_url('/hello'),
|
||||
connect_timeout=timeout,
|
||||
request_timeout=3600,
|
||||
raise_error=True)
|
||||
|
||||
@skipOnTravis
|
||||
def test_request_timeout(self):
|
||||
timeout = 0.1
|
||||
timeout_min, timeout_max = 0.099, 0.15
|
||||
if os.name == 'nt':
|
||||
timeout = 0.5
|
||||
timeout_min, timeout_max = 0.4, 0.6
|
||||
|
||||
response = self.fetch('/trigger?wake=false', request_timeout=timeout)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertTrue(timeout_min < response.request_time < timeout_max,
|
||||
response.request_time)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
self.fetch('/trigger?wake=false', request_timeout=timeout, raise_error=True)
|
||||
# trigger the hanging request to let it clean up after itself
|
||||
self.triggers.popleft()()
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_ipv6(self):
|
||||
try:
|
||||
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
|
||||
port = sock.getsockname()[1]
|
||||
self.http_server.add_socket(sock)
|
||||
except socket.gaierror as e:
|
||||
if e.args[0] == socket.EAI_ADDRFAMILY:
|
||||
# python supports ipv6, but it's not configured on the network
|
||||
# interface, so skip this test.
|
||||
return
|
||||
raise
|
||||
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
|
||||
port = sock.getsockname()[1]
|
||||
self.http_server.add_socket(sock)
|
||||
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
|
||||
|
||||
# ipv6 is currently enabled by default but can be disabled
|
||||
self.http_client.fetch(url, self.stop, allow_ipv6=False)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(Exception):
|
||||
self.fetch(url, allow_ipv6=False, raise_error=True)
|
||||
|
||||
self.http_client.fetch(url, self.stop)
|
||||
response = self.wait()
|
||||
response = self.fetch(url)
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
def xtest_multiple_content_length_accepted(self):
|
||||
def test_multiple_content_length_accepted(self):
|
||||
response = self.fetch("/content_length?value=2,2")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
response = self.fetch("/content_length?value=2,%202,2")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
|
||||
response = self.fetch("/content_length?value=2,4")
|
||||
self.assertEqual(response.code, 599)
|
||||
response = self.fetch("/content_length?value=2,%202,3")
|
||||
self.assertEqual(response.code, 599)
|
||||
with ExpectLog(gen_log, ".*Multiple unequal Content-Lengths"):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/content_length?value=2,4", raise_error=True)
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/content_length?value=2,%202,3", raise_error=True)
|
||||
|
||||
def test_head_request(self):
|
||||
response = self.fetch("/head", method="HEAD")
|
||||
|
|
@ -303,63 +315,52 @@ class SimpleHTTPClientTestMixin(object):
|
|||
def test_no_content(self):
|
||||
response = self.fetch("/no_content")
|
||||
self.assertEqual(response.code, 204)
|
||||
# 204 status doesn't need a content-length, but tornado will
|
||||
# add a zero content-length anyway.
|
||||
# 204 status shouldn't have a content-length
|
||||
#
|
||||
# A test without a content-length header is included below
|
||||
# Tests with a content-length header are included below
|
||||
# in HTTP204NoContentTestCase.
|
||||
self.assertEqual(response.headers["Content-length"], "0")
|
||||
|
||||
# 204 status with non-zero content length is malformed
|
||||
with ExpectLog(gen_log, "Malformed HTTP message"):
|
||||
response = self.fetch("/no_content?error=1")
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertNotIn("Content-Length", response.headers)
|
||||
|
||||
def test_host_header(self):
|
||||
host_re = re.compile(b"^localhost:[0-9]+$")
|
||||
host_re = re.compile(b"^127.0.0.1:[0-9]+$")
|
||||
response = self.fetch("/host_echo")
|
||||
self.assertTrue(host_re.match(response.body))
|
||||
|
||||
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
|
||||
self.http_client.fetch(url, self.stop)
|
||||
response = self.wait()
|
||||
response = self.fetch(url)
|
||||
self.assertTrue(host_re.match(response.body), response.body)
|
||||
|
||||
def test_connection_refused(self):
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
with ExpectLog(gen_log, ".*", required=False):
|
||||
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(599, response.code)
|
||||
with self.assertRaises(socket.error) as cm:
|
||||
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
|
||||
|
||||
if sys.platform != 'cygwin':
|
||||
# cygwin returns EPERM instead of ECONNREFUSED here
|
||||
contains_errno = str(errno.ECONNREFUSED) in str(response.error)
|
||||
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
|
||||
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
|
||||
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
|
||||
self.assertTrue(contains_errno, response.error)
|
||||
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception)
|
||||
self.assertTrue(contains_errno, cm.exception)
|
||||
# This is usually "Connection refused".
|
||||
# On windows, strerror is broken and returns "Unknown error".
|
||||
expected_message = os.strerror(errno.ECONNREFUSED)
|
||||
self.assertTrue(expected_message in str(response.error),
|
||||
response.error)
|
||||
self.assertTrue(expected_message in str(cm.exception),
|
||||
cm.exception)
|
||||
|
||||
def test_queue_timeout(self):
|
||||
with closing(self.create_client(max_clients=1)) as client:
|
||||
client.fetch(self.get_url('/trigger'), self.stop,
|
||||
request_timeout=10)
|
||||
# Wait for the trigger request to block, not complete.
|
||||
fut1 = client.fetch(self.get_url('/trigger'), request_timeout=10)
|
||||
self.wait()
|
||||
client.fetch(self.get_url('/hello'), self.stop,
|
||||
connect_timeout=0.1)
|
||||
response = self.wait()
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
self.io_loop.run_sync(lambda: client.fetch(
|
||||
self.get_url('/hello'), connect_timeout=0.1, raise_error=True))
|
||||
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertTrue(response.request_time < 1, response.request_time)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
self.assertEqual(str(cm.exception), "Timeout in request queue")
|
||||
self.triggers.popleft()()
|
||||
self.wait()
|
||||
self.io_loop.run_sync(lambda: fut1)
|
||||
|
||||
def test_no_content_length(self):
|
||||
response = self.fetch("/no_content_length")
|
||||
|
|
@ -375,7 +376,7 @@ class SimpleHTTPClientTestMixin(object):
|
|||
@gen.coroutine
|
||||
def async_body_producer(self, write):
|
||||
yield write(b'1234')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
yield write(b'5678')
|
||||
|
||||
def test_sync_body_producer_chunked(self):
|
||||
|
|
@ -409,7 +410,8 @@ class SimpleHTTPClientTestMixin(object):
|
|||
namespace = exec_test(globals(), locals(), """
|
||||
async def body_producer(write):
|
||||
await write(b'1234')
|
||||
await gen.Task(IOLoop.current().add_callback)
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
await write(b'5678')
|
||||
""")
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
|
|
@ -422,7 +424,8 @@ class SimpleHTTPClientTestMixin(object):
|
|||
namespace = exec_test(globals(), locals(), """
|
||||
async def body_producer(write):
|
||||
await write(b'1234')
|
||||
await gen.Task(IOLoop.current().add_callback)
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
await write(b'5678')
|
||||
""")
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
|
|
@ -470,8 +473,7 @@ class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
|
|||
self.http_client = self.create_client()
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
|
||||
**kwargs)
|
||||
return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
|
||||
|
||||
|
||||
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
||||
|
|
@ -480,7 +482,7 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
|||
self.http_client = self.create_client()
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
|
||||
return SimpleAsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(validate_cert=False),
|
||||
**kwargs)
|
||||
|
||||
|
|
@ -488,8 +490,6 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
|||
resp = self.fetch("/hello", ssl_options={})
|
||||
self.assertEqual(resp.body, b"Hello world!")
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
|
||||
'ssl.SSLContext not present')
|
||||
def test_ssl_context(self):
|
||||
resp = self.fetch("/hello",
|
||||
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
|
|
@ -498,27 +498,25 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
|||
def test_ssl_options_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
|
||||
required=False):
|
||||
resp = self.fetch(
|
||||
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
|
||||
self.assertRaises(ssl.SSLError, resp.rethrow)
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch(
|
||||
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
|
||||
raise_error=True)
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
|
||||
'ssl.SSLContext not present')
|
||||
def test_ssl_context_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
resp = self.fetch("/hello", ssl_options=ctx)
|
||||
self.assertRaises(ssl.SSLError, resp.rethrow)
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch("/hello", ssl_options=ctx, raise_error=True)
|
||||
|
||||
def test_error_logging(self):
|
||||
# No stack traces are logged for SSL errors (in this case,
|
||||
# failure to validate the testing self-signed cert).
|
||||
# The SSLError is exposed through ssl.SSLError.
|
||||
with ExpectLog(gen_log, '.*') as expect_log:
|
||||
response = self.fetch("/", validate_cert=True)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertIsInstance(response.error, ssl.SSLError)
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch("/", validate_cert=True, raise_error=True)
|
||||
self.assertFalse(expect_log.logged_stack)
|
||||
|
||||
|
||||
|
|
@ -533,24 +531,22 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
|
|||
|
||||
def test_max_clients(self):
|
||||
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
|
||||
with closing(AsyncHTTPClient(
|
||||
self.io_loop, force_instance=True)) as client:
|
||||
with closing(AsyncHTTPClient(force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 10)
|
||||
with closing(AsyncHTTPClient(
|
||||
self.io_loop, max_clients=11, force_instance=True)) as client:
|
||||
max_clients=11, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 11)
|
||||
|
||||
# Now configure max_clients statically and try overriding it
|
||||
# with each way max_clients can be passed
|
||||
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
|
||||
with closing(AsyncHTTPClient(
|
||||
self.io_loop, force_instance=True)) as client:
|
||||
with closing(AsyncHTTPClient(force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 12)
|
||||
with closing(AsyncHTTPClient(
|
||||
self.io_loop, max_clients=13, force_instance=True)) as client:
|
||||
max_clients=13, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 13)
|
||||
with closing(AsyncHTTPClient(
|
||||
self.io_loop, max_clients=14, force_instance=True)) as client:
|
||||
max_clients=14, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 14)
|
||||
|
||||
|
||||
|
|
@ -563,14 +559,15 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
|
|||
request.connection.finish()
|
||||
return
|
||||
self.request = request
|
||||
self.request.connection.stream.write(
|
||||
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
|
||||
self.respond_200)
|
||||
fut = self.request.connection.stream.write(
|
||||
b"HTTP/1.1 100 CONTINUE\r\n\r\n")
|
||||
fut.add_done_callback(self.respond_200)
|
||||
|
||||
def respond_200(self):
|
||||
self.request.connection.stream.write(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA",
|
||||
self.request.connection.stream.close)
|
||||
def respond_200(self, fut):
|
||||
fut.result()
|
||||
fut = self.request.connection.stream.write(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA")
|
||||
fut.add_done_callback(lambda f: self.request.connection.stream.close())
|
||||
|
||||
def get_app(self):
|
||||
# Not a full Application, but works as an HTTPServer callback
|
||||
|
|
@ -592,16 +589,20 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
|
|||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
|
||||
# A 204 response never has a body, even if doesn't have a content-length
|
||||
# (which would otherwise mean read-until-close). Tornado always
|
||||
# sends a content-length, so we simulate here a server that sends
|
||||
# no content length and does not close the connection.
|
||||
# (which would otherwise mean read-until-close). We simulate here a
|
||||
# server that sends no content length and does not close the connection.
|
||||
#
|
||||
# Tests of a 204 response with a Content-Length header are included
|
||||
# Tests of a 204 response with no Content-Length header are included
|
||||
# in SimpleHTTPClientTestMixin.
|
||||
stream = request.connection.detach()
|
||||
stream.write(
|
||||
b"HTTP/1.1 204 No content\r\n\r\n")
|
||||
stream.write(b"HTTP/1.1 204 No content\r\n")
|
||||
if request.arguments.get("error", [False])[-1]:
|
||||
stream.write(b"Content-Length: 5\r\n")
|
||||
else:
|
||||
stream.write(b"Content-Length: 0\r\n")
|
||||
stream.write(b"\r\n")
|
||||
stream.close()
|
||||
|
||||
def get_app(self):
|
||||
|
|
@ -614,12 +615,21 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
|
|||
self.assertEqual(resp.code, 204)
|
||||
self.assertEqual(resp.body, b'')
|
||||
|
||||
def test_204_invalid_content_length(self):
|
||||
# 204 status with non-zero content length is malformed
|
||||
with ExpectLog(gen_log, ".*Response with code 204 should not have body"):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/?error=1", raise_error=True)
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
if self.http_client.configured_class != SimpleAsyncHTTPClient:
|
||||
self.skipTest("curl client accepts invalid headers")
|
||||
|
||||
|
||||
class HostnameMappingTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(HostnameMappingTestCase, self).setUp()
|
||||
self.http_client = SimpleAsyncHTTPClient(
|
||||
self.io_loop,
|
||||
hostname_mapping={
|
||||
'www.example.com': '127.0.0.1',
|
||||
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
|
||||
|
|
@ -629,37 +639,35 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
|
|||
return Application([url("/hello", HelloWorldHandler), ])
|
||||
|
||||
def test_hostname_mapping(self):
|
||||
self.http_client.fetch(
|
||||
'http://www.example.com:%d/hello' % self.get_http_port(), self.stop)
|
||||
response = self.wait()
|
||||
response = self.fetch(
|
||||
'http://www.example.com:%d/hello' % self.get_http_port())
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
def test_port_mapping(self):
|
||||
self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
|
||||
response = self.wait()
|
||||
response = self.fetch('http://foo.example.com:8000/hello')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
|
||||
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
# Dummy Resolver subclass that never invokes its callback.
|
||||
# Dummy Resolver subclass that never finishes.
|
||||
class BadResolver(Resolver):
|
||||
@gen.coroutine
|
||||
def resolve(self, *args, **kwargs):
|
||||
pass
|
||||
yield Event().wait()
|
||||
|
||||
super(ResolveTimeoutTestCase, self).setUp()
|
||||
self.http_client = SimpleAsyncHTTPClient(
|
||||
self.io_loop,
|
||||
resolver=BadResolver())
|
||||
|
||||
def get_app(self):
|
||||
return Application([url("/hello", HelloWorldHandler), ])
|
||||
|
||||
def test_resolve_timeout(self):
|
||||
response = self.fetch('/hello', connect_timeout=0.1)
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
self.fetch('/hello', connect_timeout=0.1, raise_error=True)
|
||||
|
||||
|
||||
class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
||||
|
|
@ -678,7 +686,7 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
|||
('/large', LargeHeaders)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
|
||||
return SimpleAsyncHTTPClient(max_header_size=1024)
|
||||
|
||||
def test_small_headers(self):
|
||||
response = self.fetch('/small')
|
||||
|
|
@ -687,8 +695,8 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
|||
|
||||
def test_large_headers(self):
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
response = self.fetch('/large')
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(UnsatisfiableReadError):
|
||||
self.fetch('/large', raise_error=True)
|
||||
|
||||
|
||||
class MaxBodySizeTest(AsyncHTTPTestCase):
|
||||
|
|
@ -705,7 +713,7 @@ class MaxBodySizeTest(AsyncHTTPTestCase):
|
|||
('/large', LargeBody)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 64)
|
||||
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
|
||||
|
||||
def test_small_body(self):
|
||||
response = self.fetch('/small')
|
||||
|
|
@ -714,8 +722,8 @@ class MaxBodySizeTest(AsyncHTTPTestCase):
|
|||
|
||||
def test_large_body(self):
|
||||
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
|
||||
response = self.fetch('/large')
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch('/large', raise_error=True)
|
||||
|
||||
|
||||
class MaxBufferSizeTest(AsyncHTTPTestCase):
|
||||
|
|
@ -729,7 +737,7 @@ class MaxBufferSizeTest(AsyncHTTPTestCase):
|
|||
|
||||
def get_http_client(self):
|
||||
# 100KB body with 64KB buffer
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 100, max_buffer_size=1024 * 64)
|
||||
return SimpleAsyncHTTPClient(max_body_size=1024 * 100, max_buffer_size=1024 * 64)
|
||||
|
||||
def test_large_body(self):
|
||||
response = self.fetch('/large')
|
||||
|
|
@ -754,6 +762,6 @@ class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
|
|||
def test_chunked_with_content_length(self):
|
||||
# Make sure the invalid headers are detected
|
||||
with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
|
||||
"with both Transfer-Encoding and Content-Length")):
|
||||
response = self.fetch('/chunkwithcl')
|
||||
self.assertEqual(response.code, 599)
|
||||
"with both Transfer-Encoding and Content-Length")):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch('/chunkwithcl', raise_error=True)
|
||||
|
|
|
|||
|
|
@ -1,35 +1,36 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
|
||||
ExceptionStackContext, run_with_stack_context, _state)
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.test.util import unittest, ignore_deprecation
|
||||
from tornado.web import asynchronous, Application, RequestHandler
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
|
||||
class TestRequestHandler(RequestHandler):
|
||||
def __init__(self, app, request, io_loop):
|
||||
def __init__(self, app, request):
|
||||
super(TestRequestHandler, self).__init__(app, request)
|
||||
self.io_loop = io_loop
|
||||
|
||||
@asynchronous
|
||||
def get(self):
|
||||
logging.debug('in get()')
|
||||
# call self.part2 without a self.async_callback wrapper. Its
|
||||
# exception should still get thrown
|
||||
self.io_loop.add_callback(self.part2)
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
logging.debug('in get()')
|
||||
# call self.part2 without a self.async_callback wrapper. Its
|
||||
# exception should still get thrown
|
||||
IOLoop.current().add_callback(self.part2)
|
||||
|
||||
def part2(self):
|
||||
logging.debug('in part2()')
|
||||
# Go through a third layer to make sure that contexts once restored
|
||||
# are again passed on to future callbacks
|
||||
self.io_loop.add_callback(self.part3)
|
||||
IOLoop.current().add_callback(self.part3)
|
||||
|
||||
def part3(self):
|
||||
logging.debug('in part3()')
|
||||
|
|
@ -44,13 +45,13 @@ class TestRequestHandler(RequestHandler):
|
|||
|
||||
class HTTPStackContextTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application([('/', TestRequestHandler,
|
||||
dict(io_loop=self.io_loop))])
|
||||
return Application([('/', TestRequestHandler)])
|
||||
|
||||
def test_stack_context(self):
|
||||
with ExpectLog(app_log, "Uncaught exception GET /"):
|
||||
self.http_client.fetch(self.get_url('/'), self.handle_response)
|
||||
self.wait()
|
||||
with ignore_deprecation():
|
||||
self.http_client.fetch(self.get_url('/'), self.handle_response)
|
||||
self.wait()
|
||||
self.assertEqual(self.response.code, 500)
|
||||
self.assertTrue(b'got expected exception' in self.response.body)
|
||||
|
||||
|
|
@ -63,6 +64,13 @@ class StackContextTest(AsyncTestCase):
|
|||
def setUp(self):
|
||||
super(StackContextTest, self).setUp()
|
||||
self.active_contexts = []
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
super(StackContextTest, self).tearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(self, name):
|
||||
|
|
@ -284,5 +292,6 @@ class StackContextTest(AsyncTestCase):
|
|||
f1)
|
||||
self.assertEqual(self.active_contexts, [])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2014 Facebook
|
||||
#
|
||||
|
|
@ -14,7 +13,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from contextlib import closing
|
||||
import os
|
||||
|
|
@ -22,10 +21,12 @@ import socket
|
|||
|
||||
from tornado.concurrent import Future
|
||||
from tornado.netutil import bind_sockets, Resolver
|
||||
from tornado.queues import Queue
|
||||
from tornado.tcpclient import TCPClient, _Connector
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
|
||||
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
|
||||
from tornado.gen import TimeoutError
|
||||
|
||||
# Fake address families for testing. Used in place of AF_INET
|
||||
# and AF_INET6 because some installations do not have AF_INET6.
|
||||
|
|
@ -36,12 +37,14 @@ class TestTCPServer(TCPServer):
|
|||
def __init__(self, family):
|
||||
super(TestTCPServer, self).__init__()
|
||||
self.streams = []
|
||||
self.queue = Queue()
|
||||
sockets = bind_sockets(None, 'localhost', family)
|
||||
self.add_sockets(sockets)
|
||||
self.port = sockets[0].getsockname()[1]
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
self.streams.append(stream)
|
||||
self.queue.put(stream)
|
||||
|
||||
def stop(self):
|
||||
super(TestTCPServer, self).stop()
|
||||
|
|
@ -74,19 +77,21 @@ class TCPClientTest(AsyncTestCase):
|
|||
def skipIfLocalhostV4(self):
|
||||
# The port used here doesn't matter, but some systems require it
|
||||
# to be non-zero if we do not also pass AI_PASSIVE.
|
||||
Resolver().resolve('localhost', 80, callback=self.stop)
|
||||
addrinfo = self.wait()
|
||||
addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve('localhost', 80))
|
||||
families = set(addr[0] for addr in addrinfo)
|
||||
if socket.AF_INET6 not in families:
|
||||
self.skipTest("localhost does not resolve to ipv6")
|
||||
|
||||
@gen_test
|
||||
def do_test_connect(self, family, host):
|
||||
def do_test_connect(self, family, host, source_ip=None, source_port=None):
|
||||
port = self.start_server(family)
|
||||
stream = yield self.client.connect(host, port)
|
||||
stream = yield self.client.connect(host, port,
|
||||
source_ip=source_ip,
|
||||
source_port=source_port)
|
||||
server_stream = yield self.server.queue.get()
|
||||
with closing(stream):
|
||||
stream.write(b"hello")
|
||||
data = yield self.server.streams[0].read_bytes(5)
|
||||
data = yield server_stream.read_bytes(5)
|
||||
self.assertEqual(data, b"hello")
|
||||
|
||||
def test_connect_ipv4_ipv4(self):
|
||||
|
|
@ -125,6 +130,44 @@ class TCPClientTest(AsyncTestCase):
|
|||
with self.assertRaises(IOError):
|
||||
yield self.client.connect('127.0.0.1', port)
|
||||
|
||||
def test_source_ip_fail(self):
|
||||
'''
|
||||
Fail when trying to use the source IP Address '8.8.8.8'.
|
||||
'''
|
||||
self.assertRaises(socket.error,
|
||||
self.do_test_connect,
|
||||
socket.AF_INET,
|
||||
'127.0.0.1',
|
||||
source_ip='8.8.8.8')
|
||||
|
||||
def test_source_ip_success(self):
|
||||
'''
|
||||
Success when trying to use the source IP Address '127.0.0.1'
|
||||
'''
|
||||
self.do_test_connect(socket.AF_INET, '127.0.0.1', source_ip='127.0.0.1')
|
||||
|
||||
@skipIfNonUnix
|
||||
def test_source_port_fail(self):
|
||||
'''
|
||||
Fail when trying to use source port 1.
|
||||
'''
|
||||
self.assertRaises(socket.error,
|
||||
self.do_test_connect,
|
||||
socket.AF_INET,
|
||||
'127.0.0.1',
|
||||
source_port=1)
|
||||
|
||||
@gen_test
|
||||
def test_connect_timeout(self):
|
||||
timeout = 0.05
|
||||
|
||||
class TimeoutResolver(Resolver):
|
||||
def resolve(self, *args, **kwargs):
|
||||
return Future() # never completes
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield TCPClient(resolver=TimeoutResolver()).connect(
|
||||
'1.2.3.4', 12345, timeout=timeout)
|
||||
|
||||
|
||||
class TestConnectorSplit(unittest.TestCase):
|
||||
def test_one_family(self):
|
||||
|
|
@ -169,9 +212,11 @@ class ConnectorTest(AsyncTestCase):
|
|||
super(ConnectorTest, self).tearDown()
|
||||
|
||||
def create_stream(self, af, addr):
|
||||
stream = ConnectorTest.FakeStream()
|
||||
self.streams[addr] = stream
|
||||
future = Future()
|
||||
self.connect_futures[(af, addr)] = future
|
||||
return future
|
||||
return stream, future
|
||||
|
||||
def assert_pending(self, *keys):
|
||||
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
|
||||
|
|
@ -179,15 +224,22 @@ class ConnectorTest(AsyncTestCase):
|
|||
def resolve_connect(self, af, addr, success):
|
||||
future = self.connect_futures.pop((af, addr))
|
||||
if success:
|
||||
self.streams[addr] = ConnectorTest.FakeStream()
|
||||
future.set_result(self.streams[addr])
|
||||
else:
|
||||
self.streams.pop(addr)
|
||||
future.set_exception(IOError())
|
||||
# Run the loop to allow callbacks to be run.
|
||||
self.io_loop.add_callback(self.stop)
|
||||
self.wait()
|
||||
|
||||
def assert_connector_streams_closed(self, conn):
|
||||
for stream in conn.streams:
|
||||
self.assertTrue(stream.closed)
|
||||
|
||||
def start_connect(self, addrinfo):
|
||||
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
|
||||
conn = _Connector(addrinfo, self.create_stream)
|
||||
# Give it a huge timeout; we'll trigger timeouts manually.
|
||||
future = conn.start(3600)
|
||||
future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
|
||||
return conn, future
|
||||
|
||||
def test_immediate_success(self):
|
||||
|
|
@ -278,3 +330,101 @@ class ConnectorTest(AsyncTestCase):
|
|||
self.assertFalse(future.done())
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_timeout_after_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_connect_timeout()
|
||||
# the connector will close all streams on connect timeout, we
|
||||
# should explicitly pop the connect_future.
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
conn.on_timeout()
|
||||
# if the future is set with TimeoutError, we will not iterate next
|
||||
# possible address.
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_one_family_success_before_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
conn.on_connect_timeout()
|
||||
self.assert_pending()
|
||||
self.assertEqual(self.streams['a'].closed, False)
|
||||
# success stream will be pop
|
||||
self.assertEqual(len(conn.streams), 0)
|
||||
# streams in connector should be closed after connect timeout
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_one_family_second_try_after_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'b'))
|
||||
self.assertTrue(self.streams.pop('b').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_one_family_second_try_failure_before_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
conn.on_connect_timeout()
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_two_family_timeout_before_connect_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
self.connect_futures.pop((AF2, 'c'))
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_two_family_success_after_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
# if one of streams succeed, connector will close all other streams
|
||||
self.connect_futures.pop((AF2, 'c'))
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_two_family_timeout_after_connect_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
self.assert_pending()
|
||||
conn.on_timeout()
|
||||
# if the future is set with TimeoutError, connector will not
|
||||
# trigger secondary address.
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
import socket
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado import gen
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import NullContext
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.test.util import skipBefore35, skipIfNonUnix, exec_test, unittest
|
||||
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
|
||||
|
||||
|
||||
|
|
@ -17,7 +23,7 @@ class TCPServerTest(AsyncTestCase):
|
|||
class TestServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
yield gen.moment
|
||||
yield stream.read_bytes(len(b'hello'))
|
||||
stream.close()
|
||||
1 / 0
|
||||
|
||||
|
|
@ -30,6 +36,7 @@ class TCPServerTest(AsyncTestCase):
|
|||
client = IOStream(socket.socket())
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
yield client.connect(('localhost', port))
|
||||
yield client.write(b'hello')
|
||||
yield client.read_until_close()
|
||||
yield gen.moment
|
||||
finally:
|
||||
|
|
@ -37,3 +44,150 @@ class TCPServerTest(AsyncTestCase):
|
|||
server.stop()
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_handle_stream_native_coroutine(self):
|
||||
# handle_stream may be a native coroutine.
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
class TestServer(TCPServer):
|
||||
async def handle_stream(self, stream, address):
|
||||
stream.write(b'data')
|
||||
stream.close()
|
||||
""")
|
||||
|
||||
sock, port = bind_unused_port()
|
||||
server = namespace['TestServer']()
|
||||
server.add_socket(sock)
|
||||
client = IOStream(socket.socket())
|
||||
yield client.connect(('localhost', port))
|
||||
result = yield client.read_until_close()
|
||||
self.assertEqual(result, b'data')
|
||||
server.stop()
|
||||
client.close()
|
||||
|
||||
def test_stop_twice(self):
|
||||
sock, port = bind_unused_port()
|
||||
server = TCPServer()
|
||||
server.add_socket(sock)
|
||||
server.stop()
|
||||
server.stop()
|
||||
|
||||
@gen_test
|
||||
def test_stop_in_callback(self):
|
||||
# Issue #2069: calling server.stop() in a loop callback should not
|
||||
# raise EBADF when the loop handles other server connection
|
||||
# requests in the same loop iteration
|
||||
|
||||
class TestServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
server.stop()
|
||||
yield stream.read_until_close()
|
||||
|
||||
sock, port = bind_unused_port()
|
||||
server = TestServer()
|
||||
server.add_socket(sock)
|
||||
server_addr = ('localhost', port)
|
||||
N = 40
|
||||
clients = [IOStream(socket.socket()) for i in range(N)]
|
||||
connected_clients = []
|
||||
|
||||
@gen.coroutine
|
||||
def connect(c):
|
||||
try:
|
||||
yield c.connect(server_addr)
|
||||
except EnvironmentError:
|
||||
pass
|
||||
else:
|
||||
connected_clients.append(c)
|
||||
|
||||
yield [connect(c) for c in clients]
|
||||
|
||||
self.assertGreater(len(connected_clients), 0,
|
||||
"all clients failed connecting")
|
||||
try:
|
||||
if len(connected_clients) == N:
|
||||
# Ideally we'd make the test deterministic, but we're testing
|
||||
# for a race condition in combination with the system's TCP stack...
|
||||
self.skipTest("at least one client should fail connecting "
|
||||
"for the test to be meaningful")
|
||||
finally:
|
||||
for c in connected_clients:
|
||||
c.close()
|
||||
|
||||
# Here tearDown() would re-raise the EBADF encountered in the IO loop
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
class TestMultiprocess(unittest.TestCase):
|
||||
# These tests verify that the two multiprocess examples from the
|
||||
# TCPServer docs work. Both tests start a server with three worker
|
||||
# processes, each of which prints its task id to stdout (a single
|
||||
# byte, so we don't have to worry about atomicity of the shared
|
||||
# stdout stream) and then exits.
|
||||
def run_subproc(self, code):
|
||||
proc = subprocess.Popen(sys.executable,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE)
|
||||
proc.stdin.write(utf8(code))
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
stdout = proc.stdout.read()
|
||||
proc.stdout.close()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError("Process returned %d. stdout=%r" % (
|
||||
proc.returncode, stdout))
|
||||
return to_unicode(stdout)
|
||||
|
||||
def test_single(self):
|
||||
# As a sanity check, run the single-process version through this test
|
||||
# harness too.
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
server = TCPServer()
|
||||
server.listen(0, address='127.0.0.1')
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print('012', end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
|
||||
def test_simple(self):
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.process import task_id
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
server = TCPServer()
|
||||
server.bind(0, address='127.0.0.1')
|
||||
server.start(3)
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print(task_id(), end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
|
||||
def test_advanced(self):
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import bind_sockets
|
||||
from tornado.process import fork_processes, task_id
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
sockets = bind_sockets(0, address='127.0.0.1')
|
||||
fork_processes(3)
|
||||
server = TCPServer()
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print(task_id(), end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -7,7 +7,7 @@ import traceback
|
|||
from tornado.escape import utf8, native_str, to_unicode
|
||||
from tornado.template import Template, DictLoader, ParseError, Loader
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import u, ObjectDict, unicode_type
|
||||
from tornado.util import ObjectDict, unicode_type
|
||||
|
||||
|
||||
class TemplateTest(unittest.TestCase):
|
||||
|
|
@ -67,12 +67,13 @@ class TemplateTest(unittest.TestCase):
|
|||
self.assertRaises(ParseError, lambda: Template("{%"))
|
||||
self.assertEqual(Template("{{!").generate(), b"{{")
|
||||
self.assertEqual(Template("{%!").generate(), b"{%")
|
||||
self.assertEqual(Template("{#!").generate(), b"{#")
|
||||
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
|
||||
b"expr {{jquery expr}}")
|
||||
|
||||
def test_unicode_template(self):
|
||||
template = Template(utf8(u("\u00e9")))
|
||||
self.assertEqual(template.generate(), utf8(u("\u00e9")))
|
||||
template = Template(utf8(u"\u00e9"))
|
||||
self.assertEqual(template.generate(), utf8(u"\u00e9"))
|
||||
|
||||
def test_unicode_literal_expression(self):
|
||||
# Unicode literals should be usable in templates. Note that this
|
||||
|
|
@ -82,10 +83,10 @@ class TemplateTest(unittest.TestCase):
|
|||
if str is unicode_type:
|
||||
# python 3 needs a different version of this test since
|
||||
# 2to3 doesn't run on template internals
|
||||
template = Template(utf8(u('{{ "\u00e9" }}')))
|
||||
template = Template(utf8(u'{{ "\u00e9" }}'))
|
||||
else:
|
||||
template = Template(utf8(u('{{ u"\u00e9" }}')))
|
||||
self.assertEqual(template.generate(), utf8(u("\u00e9")))
|
||||
template = Template(utf8(u'{{ u"\u00e9" }}'))
|
||||
self.assertEqual(template.generate(), utf8(u"\u00e9"))
|
||||
|
||||
def test_custom_namespace(self):
|
||||
loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
|
||||
|
|
@ -100,14 +101,14 @@ class TemplateTest(unittest.TestCase):
|
|||
def test_unicode_apply(self):
|
||||
def upper(s):
|
||||
return to_unicode(s).upper()
|
||||
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
|
||||
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
|
||||
|
||||
def test_bytes_apply(self):
|
||||
def upper(s):
|
||||
return utf8(to_unicode(s).upper())
|
||||
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
|
||||
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
|
||||
|
||||
def test_if(self):
|
||||
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
|
||||
|
|
@ -174,8 +175,8 @@ try{% set y = 1/x %}
|
|||
self.assertEqual(template.generate(), '0')
|
||||
|
||||
def test_non_ascii_name(self):
|
||||
loader = DictLoader({u("t\u00e9st.html"): "hello"})
|
||||
self.assertEqual(loader.load(u("t\u00e9st.html")).generate(), b"hello")
|
||||
loader = DictLoader({u"t\u00e9st.html": "hello"})
|
||||
self.assertEqual(loader.load(u"t\u00e9st.html").generate(), b"hello")
|
||||
|
||||
|
||||
class StackTraceTest(unittest.TestCase):
|
||||
|
|
@ -202,10 +203,15 @@ three{%end%}
|
|||
self.assertTrue("# test.html:2" in traceback.format_exc())
|
||||
|
||||
def test_error_line_number_module(self):
|
||||
loader = None
|
||||
|
||||
def load_generate(path, **kwargs):
|
||||
return loader.load(path).generate(**kwargs)
|
||||
|
||||
loader = DictLoader({
|
||||
"base.html": "{% module Template('sub.html') %}",
|
||||
"sub.html": "{{1/0}}",
|
||||
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
|
||||
}, namespace={"_tt_modules": ObjectDict(Template=load_generate)})
|
||||
try:
|
||||
loader.load("base.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
|
|
@ -280,6 +286,11 @@ class ParseErrorDetailTest(unittest.TestCase):
|
|||
self.assertEqual("foo.html", cm.exception.filename)
|
||||
self.assertEqual(3, cm.exception.lineno)
|
||||
|
||||
def test_custom_parse_error(self):
|
||||
# Make sure that ParseErrors remain compatible with their
|
||||
# pre-4.3 signature.
|
||||
self.assertEqual("asdf at None:0", str(ParseError("asdf")))
|
||||
|
||||
|
||||
class AutoEscapeTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -482,4 +493,4 @@ class TemplateLoaderTest(unittest.TestCase):
|
|||
def test_utf8_in_file(self):
|
||||
tmpl = self.loader.load("utf8.html")
|
||||
result = tmpl.generate()
|
||||
self.assertEqual(to_unicode(result).strip(), u("H\u00e9llo"))
|
||||
self.assertEqual(to_unicode(result).strip(), u"H\u00e9llo")
|
||||
|
|
|
|||
|
|
@ -1,16 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.log import app_log
|
||||
from tornado.testing import AsyncTestCase, gen_test, ExpectLog
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPTimeoutError
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test, ExpectLog
|
||||
from tornado.web import Application
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_environ(name, value):
|
||||
|
|
@ -28,12 +34,13 @@ def set_environ(name, value):
|
|||
|
||||
class AsyncTestCaseTest(AsyncTestCase):
|
||||
def test_exception_in_callback(self):
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
self.wait()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
with ignore_deprecation():
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
self.wait()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
|
||||
def test_wait_timeout(self):
|
||||
time = self.io_loop.time
|
||||
|
|
@ -64,15 +71,56 @@ class AsyncTestCaseTest(AsyncTestCase):
|
|||
self.wait(timeout=0.15)
|
||||
|
||||
def test_multiple_errors(self):
|
||||
def fail(message):
|
||||
raise Exception(message)
|
||||
self.io_loop.add_callback(lambda: fail("error one"))
|
||||
self.io_loop.add_callback(lambda: fail("error two"))
|
||||
# The first error gets raised; the second gets logged.
|
||||
with ExpectLog(app_log, "multiple unhandled exceptions"):
|
||||
with self.assertRaises(Exception) as cm:
|
||||
self.wait()
|
||||
self.assertEqual(str(cm.exception), "error one")
|
||||
with ignore_deprecation():
|
||||
def fail(message):
|
||||
raise Exception(message)
|
||||
self.io_loop.add_callback(lambda: fail("error one"))
|
||||
self.io_loop.add_callback(lambda: fail("error two"))
|
||||
# The first error gets raised; the second gets logged.
|
||||
with ExpectLog(app_log, "multiple unhandled exceptions"):
|
||||
with self.assertRaises(Exception) as cm:
|
||||
self.wait()
|
||||
self.assertEqual(str(cm.exception), "error one")
|
||||
|
||||
|
||||
class AsyncHTTPTestCaseTest(AsyncHTTPTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(AsyncHTTPTestCaseTest, cls).setUpClass()
|
||||
# An unused port is bound so we can make requests upon it without
|
||||
# impacting a real local web server.
|
||||
cls.external_sock, cls.external_port = bind_unused_port()
|
||||
|
||||
def get_app(self):
|
||||
return Application()
|
||||
|
||||
def test_fetch_segment(self):
|
||||
path = '/path'
|
||||
response = self.fetch(path)
|
||||
self.assertEqual(response.request.url, self.get_url(path))
|
||||
|
||||
@gen_test
|
||||
def test_fetch_full_http_url(self):
|
||||
path = 'http://localhost:%d/path' % self.external_port
|
||||
|
||||
with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
yield client.fetch(path, request_timeout=0.1, raise_error=True)
|
||||
self.assertEqual(cm.exception.response.request.url, path)
|
||||
|
||||
@gen_test
|
||||
def test_fetch_full_https_url(self):
|
||||
path = 'https://localhost:%d/path' % self.external_port
|
||||
|
||||
with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
yield client.fetch(path, request_timeout=0.1, raise_error=True)
|
||||
self.assertEqual(cm.exception.response.request.url, path)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.external_sock.close()
|
||||
super(AsyncHTTPTestCaseTest, cls).tearDownClass()
|
||||
|
||||
|
||||
class AsyncTestCaseWrapperTest(unittest.TestCase):
|
||||
|
|
@ -87,6 +135,8 @@ class AsyncTestCaseWrapperTest(unittest.TestCase):
|
|||
self.assertIn("should be decorated", result.errors[0][1])
|
||||
|
||||
@skipBefore35
|
||||
@unittest.skipIf(platform.python_implementation() == 'PyPy',
|
||||
'pypy destructor warnings cannot be silenced')
|
||||
def test_undecorated_coroutine(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
class Test(AsyncTestCase):
|
||||
|
|
@ -172,14 +222,14 @@ class GenTest(AsyncTestCase):
|
|||
|
||||
@gen_test
|
||||
def test_async(self):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
self.finished = True
|
||||
|
||||
def test_timeout(self):
|
||||
# Set a short timeout and exceed it.
|
||||
@gen_test(timeout=0.1)
|
||||
def test(self):
|
||||
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
|
||||
yield gen.sleep(1)
|
||||
|
||||
# This can't use assertRaises because we need to inspect the
|
||||
# exc_info triple (and not just the exception object)
|
||||
|
|
@ -190,7 +240,7 @@ class GenTest(AsyncTestCase):
|
|||
# The stack trace should blame the add_timeout line, not just
|
||||
# unrelated IOLoop/testing internals.
|
||||
self.assertIn(
|
||||
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
|
||||
"gen.sleep(1)",
|
||||
traceback.format_exc())
|
||||
|
||||
self.finished = True
|
||||
|
|
@ -199,8 +249,7 @@ class GenTest(AsyncTestCase):
|
|||
# A test that does not exceed its timeout should succeed.
|
||||
@gen_test(timeout=1)
|
||||
def test(self):
|
||||
time = self.io_loop.time
|
||||
yield gen.Task(self.io_loop.add_timeout, time() + 0.1)
|
||||
yield gen.sleep(0.1)
|
||||
|
||||
test(self)
|
||||
self.finished = True
|
||||
|
|
@ -208,8 +257,7 @@ class GenTest(AsyncTestCase):
|
|||
def test_timeout_environment_variable(self):
|
||||
@gen_test(timeout=0.5)
|
||||
def test_long_timeout(self):
|
||||
time = self.io_loop.time
|
||||
yield gen.Task(self.io_loop.add_timeout, time() + 0.25)
|
||||
yield gen.sleep(0.25)
|
||||
|
||||
# Uses provided timeout of 0.5 seconds, doesn't time out.
|
||||
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
|
||||
|
|
@ -220,8 +268,7 @@ class GenTest(AsyncTestCase):
|
|||
def test_no_timeout_environment_variable(self):
|
||||
@gen_test(timeout=0.01)
|
||||
def test_short_timeout(self):
|
||||
time = self.io_loop.time
|
||||
yield gen.Task(self.io_loop.add_timeout, time() + 1)
|
||||
yield gen.sleep(1)
|
||||
|
||||
# Uses environment-variable timeout of 0.1, times out.
|
||||
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
|
||||
|
|
@ -234,7 +281,7 @@ class GenTest(AsyncTestCase):
|
|||
@gen_test
|
||||
def test_with_args(self, *args):
|
||||
self.assertEqual(args, ('test',))
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
|
||||
test_with_args(self, 'test')
|
||||
self.finished = True
|
||||
|
|
@ -243,7 +290,7 @@ class GenTest(AsyncTestCase):
|
|||
@gen_test
|
||||
def test_with_kwargs(self, **kwargs):
|
||||
self.assertDictEqual(kwargs, {'test': 'test'})
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
yield gen.moment
|
||||
|
||||
test_with_kwargs(self, test='test')
|
||||
self.finished = True
|
||||
|
|
@ -274,5 +321,30 @@ class GenTest(AsyncTestCase):
|
|||
self.finished = True
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class GetNewIOLoopTest(AsyncTestCase):
|
||||
def get_new_ioloop(self):
|
||||
# Use the current loop instead of creating a new one here.
|
||||
return ioloop.IOLoop.current()
|
||||
|
||||
def setUp(self):
|
||||
# This simulates the effect of an asyncio test harness like
|
||||
# pytest-asyncio.
|
||||
self.orig_loop = asyncio.get_event_loop()
|
||||
self.new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.new_loop)
|
||||
super(GetNewIOLoopTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(GetNewIOLoopTest, self).tearDown()
|
||||
# AsyncTestCase must not affect the existing asyncio loop.
|
||||
self.assertFalse(asyncio.get_event_loop().is_closed())
|
||||
asyncio.set_event_loop(self.orig_loop)
|
||||
self.new_loop.close()
|
||||
|
||||
def test_loop(self):
|
||||
self.assertIs(self.io_loop.asyncio_loop, self.new_loop)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
Unittest for the twisted-style reactor.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -28,14 +28,25 @@ import tempfile
|
|||
import threading
|
||||
import warnings
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop, PollIOLoop
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado.testing import bind_unused_port
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import import_object, PY3
|
||||
from tornado.web import RequestHandler, Application
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.python import log
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue # type: ignore
|
||||
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor # type: ignore
|
||||
from twisted.internet.protocol import Protocol # type: ignore
|
||||
from twisted.python import log # type: ignore
|
||||
from tornado.platform.twisted import TornadoReactor, TwistedIOLoop
|
||||
from zope.interface import implementer
|
||||
from zope.interface import implementer # type: ignore
|
||||
have_twisted = True
|
||||
except ImportError:
|
||||
have_twisted = False
|
||||
|
|
@ -43,38 +54,29 @@ except ImportError:
|
|||
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
|
||||
# so test for it separately.
|
||||
try:
|
||||
from twisted.web.client import Agent, readBody
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Site
|
||||
from twisted.web.client import Agent, readBody # type: ignore
|
||||
from twisted.web.resource import Resource # type: ignore
|
||||
from twisted.web.server import Site # type: ignore
|
||||
# As of Twisted 15.0.0, twisted.web is present but fails our
|
||||
# tests due to internal str/bytes errors.
|
||||
have_twisted_web = sys.version_info < (3,)
|
||||
except ImportError:
|
||||
have_twisted_web = False
|
||||
|
||||
try:
|
||||
import thread # py2
|
||||
except ImportError:
|
||||
import _thread as thread # py3
|
||||
if PY3:
|
||||
import _thread as thread
|
||||
else:
|
||||
import thread
|
||||
ResourceWarning = None
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado.platform.select import SelectIOLoop
|
||||
from tornado.testing import bind_unused_port
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import import_object
|
||||
from tornado.web import RequestHandler, Application
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
skipIfNoTwisted = unittest.skipUnless(have_twisted,
|
||||
"twisted module not present")
|
||||
|
||||
skipIfPy26 = unittest.skipIf(sys.version_info < (2, 7),
|
||||
"twisted incompatible with singledispatch in py26")
|
||||
|
||||
|
||||
def save_signal_handlers():
|
||||
saved = {}
|
||||
|
|
@ -97,8 +99,10 @@ def restore_signal_handlers(saved):
|
|||
class ReactorTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._saved_signals = save_signal_handlers()
|
||||
self._io_loop = IOLoop()
|
||||
self._reactor = TornadoReactor(self._io_loop)
|
||||
IOLoop.clear_current()
|
||||
self._io_loop = IOLoop(make_current=True)
|
||||
self._reactor = TornadoReactor()
|
||||
IOLoop.clear_current()
|
||||
|
||||
def tearDown(self):
|
||||
self._io_loop.close(all_fds=True)
|
||||
|
|
@ -219,53 +223,51 @@ class ReactorCallInThread(ReactorTestCase):
|
|||
self._reactor.run()
|
||||
|
||||
|
||||
class Reader(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
|
||||
def logPrefix(self):
|
||||
return "Reader"
|
||||
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
||||
def readConnectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def doRead(self):
|
||||
self._callback(self._fd)
|
||||
if have_twisted:
|
||||
Reader = implementer(IReadDescriptor)(Reader)
|
||||
@implementer(IReadDescriptor)
|
||||
class Reader(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
|
||||
def logPrefix(self):
|
||||
return "Reader"
|
||||
|
||||
class Writer(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
|
||||
def logPrefix(self):
|
||||
return "Writer"
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
def readConnectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
def doRead(self):
|
||||
self._callback(self._fd)
|
||||
|
||||
def doWrite(self):
|
||||
self._callback(self._fd)
|
||||
if have_twisted:
|
||||
Writer = implementer(IWriteDescriptor)(Writer)
|
||||
@implementer(IWriteDescriptor)
|
||||
class Writer(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
|
||||
def logPrefix(self):
|
||||
return "Writer"
|
||||
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def doWrite(self):
|
||||
self._callback(self._fd)
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
|
|
@ -362,7 +364,7 @@ class CompatibilityTests(unittest.TestCase):
|
|||
self.saved_signals = save_signal_handlers()
|
||||
self.io_loop = IOLoop()
|
||||
self.io_loop.make_current()
|
||||
self.reactor = TornadoReactor(self.io_loop)
|
||||
self.reactor = TornadoReactor()
|
||||
|
||||
def tearDown(self):
|
||||
self.reactor.disconnectAll()
|
||||
|
|
@ -386,7 +388,7 @@ class CompatibilityTests(unittest.TestCase):
|
|||
self.write("Hello from tornado!")
|
||||
app = Application([('/', HelloHandler)],
|
||||
log_function=lambda x: None)
|
||||
server = HTTPServer(app, io_loop=self.io_loop)
|
||||
server = HTTPServer(app)
|
||||
sock, self.tornado_port = bind_unused_port()
|
||||
server.add_sockets([sock])
|
||||
|
||||
|
|
@ -402,7 +404,7 @@ class CompatibilityTests(unittest.TestCase):
|
|||
|
||||
def tornado_fetch(self, url, runner):
|
||||
responses = []
|
||||
client = AsyncHTTPClient(self.io_loop)
|
||||
client = AsyncHTTPClient()
|
||||
|
||||
def callback(response):
|
||||
responses.append(response)
|
||||
|
|
@ -495,7 +497,6 @@ class CompatibilityTests(unittest.TestCase):
|
|||
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
@skipIfPy26
|
||||
def testTornadoServerTwistedCoroutineClientIOLoop(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_coroutine_fetch(
|
||||
|
|
@ -504,7 +505,6 @@ class CompatibilityTests(unittest.TestCase):
|
|||
|
||||
|
||||
@skipIfNoTwisted
|
||||
@skipIfPy26
|
||||
class ConvertDeferredTest(unittest.TestCase):
|
||||
def test_success(self):
|
||||
@inlineCallbacks
|
||||
|
|
@ -610,14 +610,14 @@ if have_twisted:
|
|||
test_class = import_object(test_name)
|
||||
except (ImportError, AttributeError):
|
||||
continue
|
||||
for test_func in blacklist:
|
||||
for test_func in blacklist: # type: ignore
|
||||
if hasattr(test_class, test_func):
|
||||
# The test_func may be defined in a mixin, so clobber
|
||||
# it instead of delattr()
|
||||
setattr(test_class, test_func, lambda self: None)
|
||||
|
||||
def make_test_subclass(test_class):
|
||||
class TornadoTest(test_class):
|
||||
class TornadoTest(test_class): # type: ignore
|
||||
_reactors = ["tornado.platform.twisted._TestReactor"]
|
||||
|
||||
def setUp(self):
|
||||
|
|
@ -627,10 +627,10 @@ if have_twisted:
|
|||
self.__curdir = os.getcwd()
|
||||
self.__tempdir = tempfile.mkdtemp()
|
||||
os.chdir(self.__tempdir)
|
||||
super(TornadoTest, self).setUp()
|
||||
super(TornadoTest, self).setUp() # type: ignore
|
||||
|
||||
def tearDown(self):
|
||||
super(TornadoTest, self).tearDown()
|
||||
super(TornadoTest, self).tearDown() # type: ignore
|
||||
os.chdir(self.__curdir)
|
||||
shutil.rmtree(self.__tempdir)
|
||||
|
||||
|
|
@ -645,7 +645,7 @@ if have_twisted:
|
|||
# enabled) but without our filter rules to ignore those
|
||||
# warnings from Twisted code.
|
||||
filtered = []
|
||||
for w in super(TornadoTest, self).flushWarnings(
|
||||
for w in super(TornadoTest, self).flushWarnings( # type: ignore
|
||||
*args, **kwargs):
|
||||
if w['category'] in (BytesWarning, ResourceWarning):
|
||||
continue
|
||||
|
|
@ -682,15 +682,15 @@ if have_twisted:
|
|||
|
||||
# Twisted recently introduced a new logger; disable that one too.
|
||||
try:
|
||||
from twisted.logger import globalLogBeginner
|
||||
from twisted.logger import globalLogBeginner # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
globalLogBeginner.beginLoggingTo([])
|
||||
globalLogBeginner.beginLoggingTo([], redirectStandardIO=False)
|
||||
|
||||
if have_twisted:
|
||||
class LayeredTwistedIOLoop(TwistedIOLoop):
|
||||
"""Layers a TwistedIOLoop on top of a TornadoReactor on a SelectIOLoop.
|
||||
"""Layers a TwistedIOLoop on top of a TornadoReactor on a PollIOLoop.
|
||||
|
||||
This is of course silly, but is useful for testing purposes to make
|
||||
sure we're implementing both sides of the various interfaces
|
||||
|
|
@ -698,11 +698,8 @@ if have_twisted:
|
|||
of the whole stack.
|
||||
"""
|
||||
def initialize(self, **kwargs):
|
||||
# When configured to use LayeredTwistedIOLoop we can't easily
|
||||
# get the next-best IOLoop implementation, so use the lowest common
|
||||
# denominator.
|
||||
self.real_io_loop = SelectIOLoop(make_current=False)
|
||||
reactor = TornadoReactor(io_loop=self.real_io_loop)
|
||||
self.real_io_loop = PollIOLoop(make_current=False) # type: ignore
|
||||
reactor = self.real_io_loop.run_sync(gen.coroutine(TornadoReactor))
|
||||
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
|
||||
self.add_callback(self.make_current)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,17 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
from tornado.testing import bind_unused_port
|
||||
|
||||
# Encapsulate the choice of unittest or unittest2 here.
|
||||
# To be used as 'from tornado.test.util import unittest'.
|
||||
if sys.version_info < (2, 7):
|
||||
# In py26, we must always use unittest2.
|
||||
import unittest2 as unittest
|
||||
else:
|
||||
# Otherwise, use whichever version of unittest was imported in
|
||||
# tornado.testing.
|
||||
from tornado.testing import unittest
|
||||
# Delegate the choice of unittest or unittest2 to tornado.testing.
|
||||
from tornado.testing import unittest
|
||||
|
||||
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
|
||||
"non-unix platform")
|
||||
|
|
@ -34,14 +29,39 @@ skipOnAppEngine = unittest.skipIf('APPENGINE_RUNTIME' in os.environ,
|
|||
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
|
||||
'network access disabled')
|
||||
|
||||
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
|
||||
|
||||
|
||||
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 (yield from) not available')
|
||||
skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
|
||||
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
|
||||
'Not CPython implementation')
|
||||
|
||||
# Used for tests affected by
|
||||
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
|
||||
# TODO: remove this after pypy3 5.8 is obsolete.
|
||||
skipPypy3V58 = unittest.skipIf(platform.python_implementation() == 'PyPy' and
|
||||
sys.version_info > (3,) and
|
||||
sys.pypy_version_info < (5, 9),
|
||||
'pypy3 5.8 has buggy ssl module')
|
||||
|
||||
|
||||
def _detect_ipv6():
|
||||
if not socket.has_ipv6:
|
||||
# socket.has_ipv6 check reports whether ipv6 was present at compile
|
||||
# time. It's usually true even when ipv6 doesn't work for other reasons.
|
||||
return False
|
||||
sock = None
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET6)
|
||||
sock.bind(('::1', 0))
|
||||
except socket.error:
|
||||
return False
|
||||
finally:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
|
||||
skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), 'ipv6 support not present')
|
||||
|
||||
|
||||
def refusing_port():
|
||||
"""Returns a local port number that will refuse all connections.
|
||||
|
|
@ -72,7 +92,27 @@ def exec_test(caller_globals, caller_locals, s):
|
|||
# Flatten the real global and local namespace into our fake
|
||||
# globals: it's all global from the perspective of code defined
|
||||
# in s.
|
||||
global_namespace = dict(caller_globals, **caller_locals)
|
||||
global_namespace = dict(caller_globals, **caller_locals) # type: ignore
|
||||
local_namespace = {}
|
||||
exec(textwrap.dedent(s), global_namespace, local_namespace)
|
||||
return local_namespace
|
||||
|
||||
|
||||
def subTest(test, *args, **kwargs):
|
||||
"""Compatibility shim for unittest.TestCase.subTest.
|
||||
|
||||
Usage: ``with tornado.test.util.subTest(self, x=x):``
|
||||
"""
|
||||
try:
|
||||
subTest = test.subTest # py34+
|
||||
except AttributeError:
|
||||
subTest = contextlib.contextmanager(lambda *a, **kw: (yield))
|
||||
return subTest(*args, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ignore_deprecation():
|
||||
"""Context manager to ignore deprecation warnings."""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
# coding: utf-8
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import re
|
||||
import sys
|
||||
import datetime
|
||||
|
||||
import tornado.escape
|
||||
from tornado.escape import utf8
|
||||
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import (
|
||||
raise_exc_info, Configurable, exec_in, ArgReplacer,
|
||||
timedelta_to_seconds, import_object, re_unescape, is_finalizing, PY3,
|
||||
)
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO # py2
|
||||
except ImportError:
|
||||
from io import StringIO # py3
|
||||
if PY3:
|
||||
from io import StringIO
|
||||
else:
|
||||
from cStringIO import StringIO
|
||||
|
||||
|
||||
class RaiseExcInfoTest(unittest.TestCase):
|
||||
|
|
@ -57,12 +61,35 @@ class TestConfig2(TestConfigurable):
|
|||
self.pos_arg = pos_arg
|
||||
|
||||
|
||||
class TestConfig3(TestConfigurable):
|
||||
# TestConfig3 is a configuration option that is itself configurable.
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return TestConfig3
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return TestConfig3A
|
||||
|
||||
|
||||
class TestConfig3A(TestConfig3):
|
||||
def initialize(self, a=None):
|
||||
self.a = a
|
||||
|
||||
|
||||
class TestConfig3B(TestConfig3):
|
||||
def initialize(self, b=None):
|
||||
self.b = b
|
||||
|
||||
|
||||
class ConfigurableTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.saved = TestConfigurable._save_configuration()
|
||||
self.saved3 = TestConfig3._save_configuration()
|
||||
|
||||
def tearDown(self):
|
||||
TestConfigurable._restore_configuration(self.saved)
|
||||
TestConfig3._restore_configuration(self.saved3)
|
||||
|
||||
def checkSubclasses(self):
|
||||
# no matter how the class is configured, it should always be
|
||||
|
|
@ -130,10 +157,43 @@ class ConfigurableTest(unittest.TestCase):
|
|||
obj = TestConfig2()
|
||||
self.assertIs(obj.b, None)
|
||||
|
||||
def test_config_multi_level(self):
|
||||
TestConfigurable.configure(TestConfig3, a=1)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig3A)
|
||||
self.assertEqual(obj.a, 1)
|
||||
|
||||
TestConfigurable.configure(TestConfig3)
|
||||
TestConfig3.configure(TestConfig3B, b=2)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
self.assertEqual(obj.b, 2)
|
||||
|
||||
def test_config_inner_level(self):
|
||||
# The inner level can be used even when the outer level
|
||||
# doesn't point to it.
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3A)
|
||||
|
||||
TestConfig3.configure(TestConfig3B)
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
|
||||
# Configuring the base doesn't configure the inner.
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
TestConfigurable.configure(TestConfig2)
|
||||
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
|
||||
|
||||
class UnicodeLiteralTest(unittest.TestCase):
|
||||
def test_unicode_escapes(self):
|
||||
self.assertEqual(utf8(u('\u00e9')), b'\xc3\xa9')
|
||||
self.assertEqual(utf8(u'\u00e9'), b'\xc3\xa9')
|
||||
|
||||
|
||||
class ExecInTest(unittest.TestCase):
|
||||
|
|
@ -189,7 +249,7 @@ class ImportObjectTest(unittest.TestCase):
|
|||
self.assertIs(import_object('tornado.escape.utf8'), utf8)
|
||||
|
||||
def test_import_member_unicode(self):
|
||||
self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
|
||||
self.assertIs(import_object(u'tornado.escape.utf8'), utf8)
|
||||
|
||||
def test_import_module(self):
|
||||
self.assertIs(import_object('tornado.escape'), tornado.escape)
|
||||
|
|
@ -198,4 +258,29 @@ class ImportObjectTest(unittest.TestCase):
|
|||
# The internal implementation of __import__ differs depending on
|
||||
# whether the thing being imported is a module or not.
|
||||
# This variant requires a byte string in python 2.
|
||||
self.assertIs(import_object(u('tornado.escape')), tornado.escape)
|
||||
self.assertIs(import_object(u'tornado.escape'), tornado.escape)
|
||||
|
||||
|
||||
class ReUnescapeTest(unittest.TestCase):
|
||||
def test_re_unescape(self):
|
||||
test_strings = (
|
||||
'/favicon.ico',
|
||||
'index.html',
|
||||
'Hello, World!',
|
||||
'!$@#%;',
|
||||
)
|
||||
for string in test_strings:
|
||||
self.assertEqual(string, re_unescape(re.escape(string)))
|
||||
|
||||
def test_re_unescape_raises_error_on_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\d')
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\b')
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\Z')
|
||||
|
||||
|
||||
class IsFinalizingTest(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
self.assertFalse(is_finalizing())
|
||||
|
|
|
|||
|
|
@ -1,18 +1,26 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
|
||||
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring # noqa: E501
|
||||
from tornado.httpclient import HTTPClientError
|
||||
from tornado.httputil import format_timestamp
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
from tornado import locale
|
||||
from tornado.locks import Event
|
||||
from tornado.log import app_log, gen_log
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.template import DictLoader
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
|
||||
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version, GZipContentEncoding
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
|
||||
from tornado.util import ObjectDict, unicode_type, timedelta_to_seconds, PY3
|
||||
from tornado.web import (
|
||||
Application, RequestHandler, StaticFileHandler, RedirectHandler as WebRedirectHandler,
|
||||
HTTPError, MissingArgumentError, ErrorHandler, authenticated, asynchronous, url,
|
||||
_create_signature_v1, create_signed_value, decode_signed_value, get_signature_key_version,
|
||||
UIModule, Finish, stream_request_body, removeslash, addslash, GZipContentEncoding,
|
||||
)
|
||||
|
||||
import binascii
|
||||
import contextlib
|
||||
|
|
@ -27,14 +35,16 @@ import os
|
|||
import re
|
||||
import socket
|
||||
|
||||
try:
|
||||
if PY3:
|
||||
import urllib.parse as urllib_parse # py3
|
||||
except ImportError:
|
||||
else:
|
||||
import urllib as urllib_parse # py2
|
||||
|
||||
wsgi_safe_tests = []
|
||||
|
||||
relpath = lambda *a: os.path.join(os.path.dirname(__file__), *a)
|
||||
|
||||
def relpath(*a):
|
||||
return os.path.join(os.path.dirname(__file__), *a)
|
||||
|
||||
|
||||
def wsgi_safe(cls):
|
||||
|
|
@ -181,6 +191,42 @@ class SecureCookieV2Test(unittest.TestCase):
|
|||
self.assertEqual(new_handler.get_secure_cookie('foo'), None)
|
||||
|
||||
|
||||
class FinalReturnTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
test = self
|
||||
|
||||
class FinishHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
test.final_return = self.finish()
|
||||
yield test.final_return
|
||||
|
||||
class RenderHandler(RequestHandler):
|
||||
def create_template_loader(self, path):
|
||||
return DictLoader({'foo.html': 'hi'})
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
test.final_return = self.render('foo.html')
|
||||
|
||||
return [("/finish", FinishHandler),
|
||||
("/render", RenderHandler)]
|
||||
|
||||
def get_app_kwargs(self):
|
||||
return dict(template_path='FinalReturnTest')
|
||||
|
||||
def test_finish_method_return_future(self):
|
||||
response = self.fetch(self.get_url('/finish'))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertIsInstance(self.final_return, Future)
|
||||
self.assertTrue(self.final_return.done())
|
||||
|
||||
def test_render_method_return_future(self):
|
||||
response = self.fetch(self.get_url('/render'))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertIsInstance(self.final_return, Future)
|
||||
|
||||
|
||||
class CookieTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
class SetCookieHandler(RequestHandler):
|
||||
|
|
@ -188,7 +234,7 @@ class CookieTest(WebTestCase):
|
|||
# Try setting cookies with different argument types
|
||||
# to ensure that everything gets encoded correctly
|
||||
self.set_cookie("str", "asdf")
|
||||
self.set_cookie("unicode", u("qwer"))
|
||||
self.set_cookie("unicode", u"qwer")
|
||||
self.set_cookie("bytes", b"zxcv")
|
||||
|
||||
class GetCookieHandler(RequestHandler):
|
||||
|
|
@ -199,8 +245,8 @@ class CookieTest(WebTestCase):
|
|||
def get(self):
|
||||
# unicode domain and path arguments shouldn't break things
|
||||
# either (see bug #285)
|
||||
self.set_cookie("unicode_args", "blah", domain=u("foo.com"),
|
||||
path=u("/foo"))
|
||||
self.set_cookie("unicode_args", "blah", domain=u"foo.com",
|
||||
path=u"/foo")
|
||||
|
||||
class SetCookieSpecialCharHandler(RequestHandler):
|
||||
def get(self):
|
||||
|
|
@ -277,8 +323,8 @@ class CookieTest(WebTestCase):
|
|||
|
||||
data = [('foo=a=b', 'a=b'),
|
||||
('foo="a=b"', 'a=b'),
|
||||
('foo="a;b"', 'a;b'),
|
||||
# ('foo=a\\073b', 'a;b'), # even encoded, ";" is a delimiter
|
||||
('foo="a;b"', '"a'), # even quoted, ";" is a delimiter
|
||||
('foo=a\\073b', 'a\\073b'), # escapes only decoded in quotes
|
||||
('foo="a\\073b"', 'a;b'),
|
||||
('foo="a\\"b"', 'a"b'),
|
||||
]
|
||||
|
|
@ -342,19 +388,17 @@ class AuthRedirectTest(WebTestCase):
|
|||
dict(login_url='http://example.com/login'))]
|
||||
|
||||
def test_relative_auth_redirect(self):
|
||||
self.http_client.fetch(self.get_url('/relative'), self.stop,
|
||||
follow_redirects=False)
|
||||
response = self.wait()
|
||||
response = self.fetch(self.get_url('/relative'),
|
||||
follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertEqual(response.headers['Location'], '/login?next=%2Frelative')
|
||||
|
||||
def test_absolute_auth_redirect(self):
|
||||
self.http_client.fetch(self.get_url('/absolute'), self.stop,
|
||||
follow_redirects=False)
|
||||
response = self.wait()
|
||||
response = self.fetch(self.get_url('/absolute'),
|
||||
follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(re.match(
|
||||
'http://example.com/login\?next=http%3A%2F%2Flocalhost%3A[0-9]+%2Fabsolute',
|
||||
'http://example.com/login\?next=http%3A%2F%2F127.0.0.1%3A[0-9]+%2Fabsolute',
|
||||
response.headers['Location']), response.headers['Location'])
|
||||
|
||||
|
||||
|
|
@ -362,9 +406,11 @@ class ConnectionCloseHandler(RequestHandler):
|
|||
def initialize(self, test):
|
||||
self.test = test
|
||||
|
||||
@asynchronous
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.test.on_handler_waiting()
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
def on_connection_close(self):
|
||||
self.test.on_connection_close()
|
||||
|
|
@ -377,7 +423,7 @@ class ConnectionCloseTest(WebTestCase):
|
|||
def test_connection_close(self):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("127.0.0.1", self.get_http_port()))
|
||||
self.stream = IOStream(s, io_loop=self.io_loop)
|
||||
self.stream = IOStream(s)
|
||||
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
self.wait()
|
||||
|
||||
|
|
@ -434,9 +480,9 @@ class RequestEncodingTest(WebTestCase):
|
|||
def test_group_encoding(self):
|
||||
# Path components and query arguments should be decoded the same way
|
||||
self.assertEqual(self.fetch_json('/group/%C3%A9?arg=%C3%A9'),
|
||||
{u("path"): u("/group/%C3%A9"),
|
||||
u("path_args"): [u("\u00e9")],
|
||||
u("args"): {u("arg"): [u("\u00e9")]}})
|
||||
{u"path": u"/group/%C3%A9",
|
||||
u"path_args": [u"\u00e9"],
|
||||
u"args": {u"arg": [u"\u00e9"]}})
|
||||
|
||||
def test_slashes(self):
|
||||
# Slashes may be escaped to appear as a single "directory" in the path,
|
||||
|
|
@ -541,14 +587,17 @@ class OptionalPathHandler(RequestHandler):
|
|||
class FlowControlHandler(RequestHandler):
|
||||
# These writes are too small to demonstrate real flow control,
|
||||
# but at least it shows that the callbacks get run.
|
||||
@asynchronous
|
||||
def get(self):
|
||||
self.write("1")
|
||||
self.flush(callback=self.step2)
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
self.write("1")
|
||||
with ignore_deprecation():
|
||||
self.flush(callback=self.step2)
|
||||
|
||||
def step2(self):
|
||||
self.write("2")
|
||||
self.flush(callback=self.step3)
|
||||
with ignore_deprecation():
|
||||
self.flush(callback=self.step3)
|
||||
|
||||
def step3(self):
|
||||
self.write("3")
|
||||
|
|
@ -574,14 +623,13 @@ class RedirectHandler(RequestHandler):
|
|||
|
||||
|
||||
class EmptyFlushCallbackHandler(RequestHandler):
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
# Ensure that the flush callback is run whether or not there
|
||||
# was any output. The gen.Task and direct yield forms are
|
||||
# equivalent.
|
||||
yield gen.Task(self.flush) # "empty" flush, but writes headers
|
||||
yield gen.Task(self.flush) # empty flush
|
||||
yield self.flush() # "empty" flush, but writes headers
|
||||
yield self.flush() # empty flush
|
||||
self.write("o")
|
||||
yield self.flush() # flushes the "o"
|
||||
yield self.flush() # empty flush
|
||||
|
|
@ -633,7 +681,12 @@ class WSGISafeWebTest(WebTestCase):
|
|||
{% end %}
|
||||
</body></html>""",
|
||||
"entry.html": """\
|
||||
{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }", embedded_javascript="js_embed()", css_files=["/base.css", "/foo.css"], javascript_files="/common.js", html_head="<meta>", html_body='<script src="/analytics.js"/>') }}
|
||||
{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }",
|
||||
embedded_javascript="js_embed()",
|
||||
css_files=["/base.css", "/foo.css"],
|
||||
javascript_files="/common.js",
|
||||
html_head="<meta>",
|
||||
html_body='<script src="/analytics.js"/>') }}
|
||||
<div class="entry">...</div>""",
|
||||
})
|
||||
return dict(template_loader=loader,
|
||||
|
|
@ -655,8 +708,10 @@ class WSGISafeWebTest(WebTestCase):
|
|||
url("/multi_header", MultiHeaderHandler),
|
||||
url("/redirect", RedirectHandler),
|
||||
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
|
||||
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
|
||||
url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
|
||||
url("/web_redirect", WebRedirectHandler,
|
||||
{"url": "/web_redirect_newpath", "permanent": False}),
|
||||
url("//web_redirect_double_slash", WebRedirectHandler,
|
||||
{"url": '/web_redirect_newpath'}),
|
||||
url("/header_injection", HeaderInjectionHandler),
|
||||
url("/get_argument", GetArgumentHandler),
|
||||
url("/get_arguments", GetArgumentsHandler),
|
||||
|
|
@ -690,15 +745,15 @@ class WSGISafeWebTest(WebTestCase):
|
|||
response = self.fetch(req_url)
|
||||
response.rethrow()
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u('path'): [u('unicode'), u('\u00e9')],
|
||||
u('query'): [u('unicode'), u('\u00e9')],
|
||||
self.assertEqual(data, {u'path': [u'unicode', u'\u00e9'],
|
||||
u'query': [u'unicode', u'\u00e9'],
|
||||
})
|
||||
|
||||
response = self.fetch("/decode_arg/%C3%A9?foo=%C3%A9")
|
||||
response.rethrow()
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u('path'): [u('bytes'), u('c3a9')],
|
||||
u('query'): [u('bytes'), u('c3a9')],
|
||||
self.assertEqual(data, {u'path': [u'bytes', u'c3a9'],
|
||||
u'query': [u'bytes', u'c3a9'],
|
||||
})
|
||||
|
||||
def test_decode_argument_invalid_unicode(self):
|
||||
|
|
@ -717,8 +772,8 @@ class WSGISafeWebTest(WebTestCase):
|
|||
response = self.fetch(req_url)
|
||||
response.rethrow()
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {u('path'): [u('unicode'), u('1 + 1')],
|
||||
u('query'): [u('unicode'), u('1 + 1')],
|
||||
self.assertEqual(data, {u'path': [u'unicode', u'1 + 1'],
|
||||
u'query': [u'unicode', u'1 + 1'],
|
||||
})
|
||||
|
||||
def test_reverse_url(self):
|
||||
|
|
@ -728,7 +783,7 @@ class WSGISafeWebTest(WebTestCase):
|
|||
'/decode_arg/42')
|
||||
self.assertEqual(self.app.reverse_url('decode_arg', b'\xe9'),
|
||||
'/decode_arg/%E9')
|
||||
self.assertEqual(self.app.reverse_url('decode_arg', u('\u00e9')),
|
||||
self.assertEqual(self.app.reverse_url('decode_arg', u'\u00e9'),
|
||||
'/decode_arg/%C3%A9')
|
||||
self.assertEqual(self.app.reverse_url('decode_arg', '1 + 1'),
|
||||
'/decode_arg/1%20%2B%201')
|
||||
|
|
@ -761,13 +816,13 @@ js_embed()
|
|||
//]]>
|
||||
</script>
|
||||
<script src="/analytics.js"/>
|
||||
</body></html>""")
|
||||
</body></html>""") # noqa: E501
|
||||
|
||||
def test_optional_path(self):
|
||||
self.assertEqual(self.fetch_json("/optional_path/foo"),
|
||||
{u("path"): u("foo")})
|
||||
{u"path": u"foo"})
|
||||
self.assertEqual(self.fetch_json("/optional_path/"),
|
||||
{u("path"): None})
|
||||
{u"path": None})
|
||||
|
||||
def test_multi_header(self):
|
||||
response = self.fetch("/multi_header")
|
||||
|
|
@ -915,6 +970,10 @@ class ErrorResponseTest(WebTestCase):
|
|||
self.assertEqual(response.code, 503)
|
||||
self.assertTrue(b"503: Service Unavailable" in response.body)
|
||||
|
||||
response = self.fetch("/default?status=435")
|
||||
self.assertEqual(response.code, 435)
|
||||
self.assertTrue(b"435: Unknown" in response.body)
|
||||
|
||||
def test_write_error(self):
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = self.fetch("/write_error")
|
||||
|
|
@ -1062,6 +1121,13 @@ class StaticFileTest(WebTestCase):
|
|||
'If-None-Match': response1.headers['Etag']})
|
||||
self.assertEqual(response2.code, 304)
|
||||
|
||||
def test_static_304_etag_modified_bug(self):
|
||||
response1 = self.get_and_head("/static/robots.txt")
|
||||
response2 = self.get_and_head("/static/robots.txt", headers={
|
||||
'If-None-Match': '"MISMATCH"',
|
||||
'If-Modified-Since': response1.headers['Last-Modified']})
|
||||
self.assertEqual(response2.code, 200)
|
||||
|
||||
def test_static_if_modified_since_pre_epoch(self):
|
||||
# On windows, the functions that work with time_t do not accept
|
||||
# negative values, and at least one client (processing.js) seems
|
||||
|
|
@ -1346,6 +1412,8 @@ class HostMatchingTest(WebTestCase):
|
|||
[("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
|
||||
self.app.add_handlers("www.example.com",
|
||||
[("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
|
||||
self.app.add_handlers("www.e.*e.com",
|
||||
[("/baz", HostMatchingTest.Handler, {"reply": "[3]"})])
|
||||
|
||||
response = self.fetch("/foo")
|
||||
self.assertEqual(response.body, b"wildcard")
|
||||
|
|
@ -1360,6 +1428,40 @@ class HostMatchingTest(WebTestCase):
|
|||
self.assertEqual(response.body, b"[1]")
|
||||
response = self.fetch("/baz", headers={'Host': 'www.example.com'})
|
||||
self.assertEqual(response.body, b"[2]")
|
||||
response = self.fetch("/baz", headers={'Host': 'www.exe.com'})
|
||||
self.assertEqual(response.body, b"[3]")
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class DefaultHostMatchingTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
return []
|
||||
|
||||
def get_app_kwargs(self):
|
||||
return {'default_host': "www.example.com"}
|
||||
|
||||
def test_default_host_matching(self):
|
||||
self.app.add_handlers("www.example.com",
|
||||
[("/foo", HostMatchingTest.Handler, {"reply": "[0]"})])
|
||||
self.app.add_handlers(r"www\.example\.com",
|
||||
[("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
|
||||
self.app.add_handlers("www.test.com",
|
||||
[("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
|
||||
|
||||
response = self.fetch("/foo")
|
||||
self.assertEqual(response.body, b"[0]")
|
||||
response = self.fetch("/bar")
|
||||
self.assertEqual(response.body, b"[1]")
|
||||
response = self.fetch("/baz")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch("/foo", headers={"X-Real-Ip": "127.0.0.1"})
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
self.app.default_host = "www.test.com"
|
||||
|
||||
response = self.fetch("/baz")
|
||||
self.assertEqual(response.body, b"[2]")
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
|
|
@ -1370,7 +1472,7 @@ class NamedURLSpecGroupsTest(WebTestCase):
|
|||
self.write(path)
|
||||
|
||||
return [("/str/(?P<path>.*)", EchoHandler),
|
||||
(u("/unicode/(?P<path>.*)"), EchoHandler)]
|
||||
(u"/unicode/(?P<path>.*)", EchoHandler)]
|
||||
|
||||
def test_named_urlspec_groups(self):
|
||||
response = self.fetch("/str/foo")
|
||||
|
|
@ -1395,6 +1497,19 @@ class ClearHeaderTest(SimpleHandlerTestCase):
|
|||
self.assertEqual(response.headers["h2"], "bar")
|
||||
|
||||
|
||||
class Header204Test(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
def test_204_headers(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 204)
|
||||
self.assertNotIn("Content-Length", response.headers)
|
||||
self.assertNotIn("Transfer-Encoding", response.headers)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class Header304Test(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
|
|
@ -1426,7 +1541,7 @@ class StatusReasonTest(SimpleHandlerTestCase):
|
|||
|
||||
def get_http_client(self):
|
||||
# simple_httpclient only: curl doesn't expose the reason string
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def test_status(self):
|
||||
response = self.fetch("/?code=304")
|
||||
|
|
@ -1438,9 +1553,9 @@ class StatusReasonTest(SimpleHandlerTestCase):
|
|||
response = self.fetch("/?code=682&reason=Bar")
|
||||
self.assertEqual(response.code, 682)
|
||||
self.assertEqual(response.reason, "Bar")
|
||||
with ExpectLog(app_log, 'Uncaught exception'):
|
||||
response = self.fetch("/?code=682")
|
||||
self.assertEqual(response.code, 500)
|
||||
response = self.fetch("/?code=682")
|
||||
self.assertEqual(response.code, 682)
|
||||
self.assertEqual(response.reason, "Unknown")
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
|
|
@ -1465,7 +1580,7 @@ class RaiseWithReasonTest(SimpleHandlerTestCase):
|
|||
|
||||
def get_http_client(self):
|
||||
# simple_httpclient only: curl doesn't expose the reason string
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def test_raise_with_reason(self):
|
||||
response = self.fetch("/")
|
||||
|
|
@ -1476,6 +1591,9 @@ class RaiseWithReasonTest(SimpleHandlerTestCase):
|
|||
def test_httperror_str(self):
|
||||
self.assertEqual(str(HTTPError(682, reason="Foo")), "HTTP 682: Foo")
|
||||
|
||||
def test_httperror_str_from_httputil(self):
|
||||
self.assertEqual(str(HTTPError(682)), "HTTP 682: Unknown")
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class ErrorHandlerXSRFTest(WebTestCase):
|
||||
|
|
@ -1501,8 +1619,8 @@ class ErrorHandlerXSRFTest(WebTestCase):
|
|||
class GzipTestCase(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
if self.get_argument('vary', None):
|
||||
self.set_header('Vary', self.get_argument('vary'))
|
||||
for v in self.get_arguments('vary'):
|
||||
self.add_header('Vary', v)
|
||||
# Must write at least MIN_LENGTH bytes to activate compression.
|
||||
self.write('hello world' + ('!' * GZipContentEncoding.MIN_LENGTH))
|
||||
|
||||
|
|
@ -1511,8 +1629,7 @@ class GzipTestCase(SimpleHandlerTestCase):
|
|||
gzip=True,
|
||||
static_path=os.path.join(os.path.dirname(__file__), 'static'))
|
||||
|
||||
def test_gzip(self):
|
||||
response = self.fetch('/')
|
||||
def assert_compressed(self, response):
|
||||
# simple_httpclient renames the content-encoding header;
|
||||
# curl_httpclient doesn't.
|
||||
self.assertEqual(
|
||||
|
|
@ -1520,17 +1637,17 @@ class GzipTestCase(SimpleHandlerTestCase):
|
|||
'Content-Encoding',
|
||||
response.headers.get('X-Consumed-Content-Encoding')),
|
||||
'gzip')
|
||||
|
||||
def test_gzip(self):
|
||||
response = self.fetch('/')
|
||||
self.assert_compressed(response)
|
||||
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
|
||||
|
||||
def test_gzip_static(self):
|
||||
# The streaming responses in StaticFileHandler have subtle
|
||||
# interactions with the gzip output so test this case separately.
|
||||
response = self.fetch('/robots.txt')
|
||||
self.assertEqual(
|
||||
response.headers.get(
|
||||
'Content-Encoding',
|
||||
response.headers.get('X-Consumed-Content-Encoding')),
|
||||
'gzip')
|
||||
self.assert_compressed(response)
|
||||
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
|
||||
|
||||
def test_gzip_not_requested(self):
|
||||
|
|
@ -1540,8 +1657,16 @@ class GzipTestCase(SimpleHandlerTestCase):
|
|||
|
||||
def test_vary_already_present(self):
|
||||
response = self.fetch('/?vary=Accept-Language')
|
||||
self.assertEqual(response.headers['Vary'],
|
||||
'Accept-Language, Accept-Encoding')
|
||||
self.assert_compressed(response)
|
||||
self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
|
||||
['Accept-Language', 'Accept-Encoding'])
|
||||
|
||||
def test_vary_already_present_multiple(self):
|
||||
# Regression test for https://github.com/tornadoweb/tornado/issues/1670
|
||||
response = self.fetch('/?vary=Accept-Language&vary=Cookie')
|
||||
self.assert_compressed(response)
|
||||
self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
|
||||
['Accept-Language', 'Cookie', 'Accept-Encoding'])
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
|
|
@ -1721,28 +1846,29 @@ class MultipleExceptionTest(SimpleHandlerTestCase):
|
|||
class Handler(RequestHandler):
|
||||
exc_count = 0
|
||||
|
||||
@asynchronous
|
||||
def get(self):
|
||||
from tornado.ioloop import IOLoop
|
||||
IOLoop.current().add_callback(lambda: 1 / 0)
|
||||
IOLoop.current().add_callback(lambda: 1 / 0)
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
IOLoop.current().add_callback(lambda: 1 / 0)
|
||||
IOLoop.current().add_callback(lambda: 1 / 0)
|
||||
|
||||
def log_exception(self, typ, value, tb):
|
||||
MultipleExceptionTest.Handler.exc_count += 1
|
||||
|
||||
def test_multi_exception(self):
|
||||
# This test verifies that multiple exceptions raised into the same
|
||||
# ExceptionStackContext do not generate extraneous log entries
|
||||
# due to "Cannot send error response after headers written".
|
||||
# log_exception is called, but it does not proceed to send_error.
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 500)
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 500)
|
||||
# Each of our two requests generated two exceptions, we should have
|
||||
# seen at least three of them by now (the fourth may still be
|
||||
# in the queue).
|
||||
self.assertGreater(MultipleExceptionTest.Handler.exc_count, 2)
|
||||
with ignore_deprecation():
|
||||
# This test verifies that multiple exceptions raised into the same
|
||||
# ExceptionStackContext do not generate extraneous log entries
|
||||
# due to "Cannot send error response after headers written".
|
||||
# log_exception is called, but it does not proceed to send_error.
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 500)
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 500)
|
||||
# Each of our two requests generated two exceptions, we should have
|
||||
# seen at least three of them by now (the fourth may still be
|
||||
# in the queue).
|
||||
self.assertGreater(MultipleExceptionTest.Handler.exc_count, 2)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
|
|
@ -2050,7 +2176,7 @@ class StreamingRequestBodyTest(WebTestCase):
|
|||
# Use a raw connection so we can control the sending of data.
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("127.0.0.1", self.get_http_port()))
|
||||
stream = IOStream(s, io_loop=self.io_loop)
|
||||
stream = IOStream(s)
|
||||
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
|
||||
if connection_close:
|
||||
stream.write(b"Connection: close\r\n")
|
||||
|
|
@ -2073,9 +2199,9 @@ class StreamingRequestBodyTest(WebTestCase):
|
|||
stream.write(b"4\r\nqwer\r\n")
|
||||
data = yield self.data
|
||||
self.assertEquals(data, b"qwer")
|
||||
stream.write(b"0\r\n")
|
||||
stream.write(b"0\r\n\r\n")
|
||||
yield self.finished
|
||||
data = yield gen.Task(stream.read_until_close)
|
||||
data = yield stream.read_until_close()
|
||||
# This would ideally use an HTTP1Connection to read the response.
|
||||
self.assertTrue(data.endswith(b"{}"))
|
||||
stream.close()
|
||||
|
|
@ -2083,14 +2209,14 @@ class StreamingRequestBodyTest(WebTestCase):
|
|||
@gen_test
|
||||
def test_early_return(self):
|
||||
stream = self.connect(b"/early_return", connection_close=False)
|
||||
data = yield gen.Task(stream.read_until_close)
|
||||
data = yield stream.read_until_close()
|
||||
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
|
||||
|
||||
@gen_test
|
||||
def test_early_return_with_data(self):
|
||||
stream = self.connect(b"/early_return", connection_close=False)
|
||||
stream.write(b"4\r\nasdf\r\n")
|
||||
data = yield gen.Task(stream.read_until_close)
|
||||
data = yield stream.read_until_close()
|
||||
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
|
||||
|
||||
@gen_test
|
||||
|
|
@ -2129,12 +2255,12 @@ class BaseFlowControlHandler(RequestHandler):
|
|||
# Note that asynchronous prepare() does not block data_received,
|
||||
# so we don't use in_method here.
|
||||
self.methods.append('prepare')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
|
||||
@gen.coroutine
|
||||
def post(self):
|
||||
with self.in_method('post'):
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
self.write(dict(methods=self.methods))
|
||||
|
||||
|
||||
|
|
@ -2146,7 +2272,7 @@ class BaseStreamingRequestFlowControlTest(object):
|
|||
|
||||
def get_http_client(self):
|
||||
# simple_httpclient only: curl doesn't support body_producer.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
# Test all the slightly different code paths for fixed, chunked, etc bodies.
|
||||
def test_flow_control_fixed_body(self):
|
||||
|
|
@ -2160,6 +2286,7 @@ class BaseStreamingRequestFlowControlTest(object):
|
|||
|
||||
def test_flow_control_chunked_body(self):
|
||||
chunks = [b'abcd', b'efgh', b'ijkl']
|
||||
|
||||
@gen.coroutine
|
||||
def body_producer(write):
|
||||
for i in chunks:
|
||||
|
|
@ -2185,6 +2312,7 @@ class BaseStreamingRequestFlowControlTest(object):
|
|||
'data_received', 'data_received',
|
||||
'post']))
|
||||
|
||||
|
||||
class DecoratedStreamingRequestFlowControlTest(
|
||||
BaseStreamingRequestFlowControlTest,
|
||||
WebTestCase):
|
||||
|
|
@ -2193,7 +2321,7 @@ class DecoratedStreamingRequestFlowControlTest(
|
|||
@gen.coroutine
|
||||
def data_received(self, data):
|
||||
with self.in_method('data_received'):
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
yield gen.moment
|
||||
return [('/', DecoratedFlowControlHandler, dict(test=self))]
|
||||
|
||||
|
||||
|
|
@ -2206,7 +2334,8 @@ class NativeStreamingRequestFlowControlTest(
|
|||
data_received = exec_test(globals(), locals(), """
|
||||
async def data_received(self, data):
|
||||
with self.in_method('data_received'):
|
||||
await gen.Task(IOLoop.current().add_callback)
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
""")["data_received"]
|
||||
return [('/', NativeFlowControlHandler, dict(test=self))]
|
||||
|
||||
|
|
@ -2247,8 +2376,8 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
|
|||
with ExpectLog(gen_log,
|
||||
"(Cannot send error response after headers written"
|
||||
"|Failed to flush partial response)"):
|
||||
response = self.fetch("/high")
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(HTTPClientError):
|
||||
self.fetch("/high", raise_error=True)
|
||||
self.assertEqual(str(self.server_error),
|
||||
"Tried to write 40 bytes less than Content-Length")
|
||||
|
||||
|
|
@ -2260,8 +2389,8 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
|
|||
with ExpectLog(gen_log,
|
||||
"(Cannot send error response after headers written"
|
||||
"|Failed to flush partial response)"):
|
||||
response = self.fetch("/low")
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises(HTTPClientError):
|
||||
self.fetch("/low", raise_error=True)
|
||||
self.assertEqual(str(self.server_error),
|
||||
"Tried to write more data than Content-Length")
|
||||
|
||||
|
|
@ -2282,10 +2411,11 @@ class ClientCloseTest(SimpleHandlerTestCase):
|
|||
self.write('requires HTTP/1.x')
|
||||
|
||||
def test_client_close(self):
|
||||
response = self.fetch('/')
|
||||
if response.body == b'requires HTTP/1.x':
|
||||
self.skipTest('requires HTTP/1.x')
|
||||
self.assertEqual(response.code, 599)
|
||||
with self.assertRaises((HTTPClientError, unittest.SkipTest)):
|
||||
response = self.fetch('/', raise_error=True)
|
||||
if response.body == b'requires HTTP/1.x':
|
||||
self.skipTest('requires HTTP/1.x')
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
class SignedValueTest(unittest.TestCase):
|
||||
|
|
@ -2483,6 +2613,22 @@ class XSRFTest(SimpleHandlerTestCase):
|
|||
body=urllib_parse.urlencode(dict(_xsrf=self.xsrf_token)))
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
def test_xsrf_fail_argument_invalid_format(self):
|
||||
with ExpectLog(gen_log, ".*'_xsrf' argument has invalid format"):
|
||||
response = self.fetch(
|
||||
"/", method="POST",
|
||||
headers=self.cookie_headers(),
|
||||
body=urllib_parse.urlencode(dict(_xsrf='3|')))
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
def test_xsrf_fail_cookie_invalid_format(self):
|
||||
with ExpectLog(gen_log, ".*XSRF cookie does not match POST"):
|
||||
response = self.fetch(
|
||||
"/", method="POST",
|
||||
headers=self.cookie_headers(token='3|'),
|
||||
body=urllib_parse.urlencode(dict(_xsrf=self.xsrf_token)))
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
def test_xsrf_fail_cookie_no_body(self):
|
||||
with ExpectLog(gen_log, ".*'_xsrf' argument missing"):
|
||||
response = self.fetch(
|
||||
|
|
@ -2520,7 +2666,7 @@ class XSRFTest(SimpleHandlerTestCase):
|
|||
|
||||
def test_xsrf_success_header(self):
|
||||
response = self.fetch("/", method="POST", body=b"",
|
||||
headers=dict({"X-Xsrftoken": self.xsrf_token},
|
||||
headers=dict({"X-Xsrftoken": self.xsrf_token}, # type: ignore
|
||||
**self.cookie_headers()))
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
|
|
@ -2623,8 +2769,8 @@ class FinishExceptionTest(SimpleHandlerTestCase):
|
|||
raise Finish()
|
||||
|
||||
def test_finish_exception(self):
|
||||
for url in ['/', '/?finish_value=1']:
|
||||
response = self.fetch(url)
|
||||
for u in ['/', '/?finish_value=1']:
|
||||
response = self.fetch(u)
|
||||
self.assertEqual(response.code, 401)
|
||||
self.assertEqual('Basic realm="something"',
|
||||
response.headers.get('WWW-Authenticate'))
|
||||
|
|
@ -2763,3 +2909,59 @@ class ApplicationTest(AsyncTestCase):
|
|||
app = Application([])
|
||||
server = app.listen(0, address='127.0.0.1')
|
||||
server.stop()
|
||||
|
||||
|
||||
class URLSpecReverseTest(unittest.TestCase):
|
||||
def test_reverse(self):
|
||||
self.assertEqual('/favicon.ico', url(r'/favicon\.ico', None).reverse())
|
||||
self.assertEqual('/favicon.ico', url(r'^/favicon\.ico$', None).reverse())
|
||||
|
||||
def test_non_reversible(self):
|
||||
# URLSpecs are non-reversible if they include non-constant
|
||||
# regex features outside capturing groups. Currently, this is
|
||||
# only strictly enforced for backslash-escaped character
|
||||
# classes.
|
||||
paths = [
|
||||
r'^/api/v\d+/foo/(\w+)$',
|
||||
]
|
||||
for path in paths:
|
||||
# A URLSpec can still be created even if it cannot be reversed.
|
||||
url_spec = url(path, None)
|
||||
try:
|
||||
result = url_spec.reverse()
|
||||
self.fail("did not get expected exception when reversing %s. "
|
||||
"result: %s" % (path, result))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_reverse_arguments(self):
|
||||
self.assertEqual('/api/v1/foo/bar',
|
||||
url(r'^/api/v1/foo/(\w+)$', None).reverse('bar'))
|
||||
|
||||
|
||||
class RedirectHandlerTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
return [
|
||||
('/src', WebRedirectHandler, {'url': '/dst'}),
|
||||
('/src2', WebRedirectHandler, {'url': '/dst2?foo=bar'}),
|
||||
(r'/(.*?)/(.*?)/(.*)', WebRedirectHandler, {'url': '/{1}/{0}/{2}'})]
|
||||
|
||||
def test_basic_redirect(self):
|
||||
response = self.fetch('/src', follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], '/dst')
|
||||
|
||||
def test_redirect_with_argument(self):
|
||||
response = self.fetch('/src?foo=bar', follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], '/dst?foo=bar')
|
||||
|
||||
def test_redirect_with_appending_argument(self):
|
||||
response = self.fetch('/src2?foo2=bar2', follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], '/dst2?foo=bar&foo2=bar2')
|
||||
|
||||
def test_redirect_pattern(self):
|
||||
response = self.fetch('/a/b/c', follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], '/b/a/c')
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPError, HTTPRequest
|
||||
from tornado.locks import Event
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.template import DictLoader
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
from tornado.web import Application, RequestHandler
|
||||
from tornado.util import u
|
||||
|
||||
try:
|
||||
import tornado.websocket # noqa
|
||||
|
|
@ -22,7 +26,9 @@ except ImportError:
|
|||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
|
||||
from tornado.websocket import (
|
||||
WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError,
|
||||
)
|
||||
|
||||
try:
|
||||
from tornado import speedups
|
||||
|
|
@ -47,8 +53,12 @@ class TestWebSocketHandler(WebSocketHandler):
|
|||
|
||||
|
||||
class EchoHandler(TestWebSocketHandler):
|
||||
@gen.coroutine
|
||||
def on_message(self, message):
|
||||
self.write_message(message, isinstance(message, bytes))
|
||||
try:
|
||||
yield self.write_message(message, isinstance(message, bytes))
|
||||
except WebSocketClosedError:
|
||||
pass
|
||||
|
||||
|
||||
class ErrorInOnMessageHandler(TestWebSocketHandler):
|
||||
|
|
@ -58,16 +68,36 @@ class ErrorInOnMessageHandler(TestWebSocketHandler):
|
|||
|
||||
class HeaderHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
try:
|
||||
# In a websocket context, many RequestHandler methods
|
||||
# raise RuntimeErrors.
|
||||
self.set_status(503)
|
||||
raise Exception("did not get expected exception")
|
||||
except RuntimeError:
|
||||
pass
|
||||
methods_to_test = [
|
||||
functools.partial(self.write, 'This should not work'),
|
||||
functools.partial(self.redirect, 'http://localhost/elsewhere'),
|
||||
functools.partial(self.set_header, 'X-Test', ''),
|
||||
functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
|
||||
functools.partial(self.set_status, 503),
|
||||
self.flush,
|
||||
self.finish,
|
||||
]
|
||||
for method in methods_to_test:
|
||||
try:
|
||||
# In a websocket context, many RequestHandler methods
|
||||
# raise RuntimeErrors.
|
||||
method()
|
||||
raise Exception("did not get expected exception")
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.write_message(self.request.headers.get('X-Test', ''))
|
||||
|
||||
|
||||
class HeaderEchoHandler(TestWebSocketHandler):
|
||||
def set_default_headers(self):
|
||||
self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
|
||||
|
||||
def prepare(self):
|
||||
for k, v in self.request.headers.get_all():
|
||||
if k.lower().startswith('x-test'):
|
||||
self.set_header(k, v)
|
||||
|
||||
|
||||
class NonWebSocketHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('ok')
|
||||
|
|
@ -88,12 +118,75 @@ class AsyncPrepareHandler(TestWebSocketHandler):
|
|||
self.write_message(message)
|
||||
|
||||
|
||||
class PathArgsHandler(TestWebSocketHandler):
|
||||
def open(self, arg):
|
||||
self.write_message(arg)
|
||||
|
||||
|
||||
class CoroutineOnMessageHandler(TestWebSocketHandler):
|
||||
def initialize(self, close_future, compression_options=None):
|
||||
super(CoroutineOnMessageHandler, self).initialize(close_future,
|
||||
compression_options)
|
||||
self.sleeping = 0
|
||||
|
||||
@gen.coroutine
|
||||
def on_message(self, message):
|
||||
if self.sleeping > 0:
|
||||
self.write_message('another coroutine is already sleeping')
|
||||
self.sleeping += 1
|
||||
yield gen.sleep(0.01)
|
||||
self.sleeping -= 1
|
||||
self.write_message(message)
|
||||
|
||||
|
||||
class RenderMessageHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
self.write_message(self.render_string('message.html', message=message))
|
||||
|
||||
|
||||
class SubprotocolHandler(TestWebSocketHandler):
|
||||
def initialize(self, **kwargs):
|
||||
super(SubprotocolHandler, self).initialize(**kwargs)
|
||||
self.select_subprotocol_called = False
|
||||
|
||||
def select_subprotocol(self, subprotocols):
|
||||
if self.select_subprotocol_called:
|
||||
raise Exception("select_subprotocol called twice")
|
||||
self.select_subprotocol_called = True
|
||||
if 'goodproto' in subprotocols:
|
||||
return 'goodproto'
|
||||
return None
|
||||
|
||||
def open(self):
|
||||
if not self.select_subprotocol_called:
|
||||
raise Exception("select_subprotocol not called")
|
||||
self.write_message("subprotocol=%s" % self.selected_subprotocol)
|
||||
|
||||
|
||||
class OpenCoroutineHandler(TestWebSocketHandler):
|
||||
def initialize(self, test, **kwargs):
|
||||
super(OpenCoroutineHandler, self).initialize(**kwargs)
|
||||
self.test = test
|
||||
self.open_finished = False
|
||||
|
||||
@gen.coroutine
|
||||
def open(self):
|
||||
yield self.test.message_sent.wait()
|
||||
yield gen.sleep(0.010)
|
||||
self.open_finished = True
|
||||
|
||||
def on_message(self, message):
|
||||
if not self.open_finished:
|
||||
raise Exception('on_message called before open finished')
|
||||
self.write_message('ok')
|
||||
|
||||
|
||||
class WebSocketBaseTestCase(AsyncHTTPTestCase):
|
||||
@gen.coroutine
|
||||
def ws_connect(self, path, compression_options=None):
|
||||
def ws_connect(self, path, **kwargs):
|
||||
ws = yield websocket_connect(
|
||||
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
|
||||
compression_options=compression_options)
|
||||
**kwargs)
|
||||
raise gen.Return(ws)
|
||||
|
||||
@gen.coroutine
|
||||
|
|
@ -114,19 +207,55 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
('/echo', EchoHandler, dict(close_future=self.close_future)),
|
||||
('/non_ws', NonWebSocketHandler),
|
||||
('/header', HeaderHandler, dict(close_future=self.close_future)),
|
||||
('/header_echo', HeaderEchoHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/close_reason', CloseReasonHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/error_in_on_message', ErrorInOnMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/async_prepare', AsyncPrepareHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
])
|
||||
('/path_args/(.*)', PathArgsHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/coroutine', CoroutineOnMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/render', RenderMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/subprotocol', SubprotocolHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/open_coroutine', OpenCoroutineHandler,
|
||||
dict(close_future=self.close_future, test=self)),
|
||||
], template_loader=DictLoader({
|
||||
'message.html': '<b>{{ message }}</b>',
|
||||
}))
|
||||
|
||||
def get_http_client(self):
|
||||
# These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def tearDown(self):
|
||||
super(WebSocketTest, self).tearDown()
|
||||
RequestHandler._template_loaders.clear()
|
||||
|
||||
def test_http_request(self):
|
||||
# WS server, HTTP client.
|
||||
response = self.fetch('/echo')
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_missing_websocket_key(self):
|
||||
response = self.fetch('/echo',
|
||||
headers={'Connection': 'Upgrade',
|
||||
'Upgrade': 'WebSocket',
|
||||
'Sec-WebSocket-Version': '13'})
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_bad_websocket_version(self):
|
||||
response = self.fetch('/echo',
|
||||
headers={'Connection': 'Upgrade',
|
||||
'Upgrade': 'WebSocket',
|
||||
'Sec-WebSocket-Version': '12'})
|
||||
self.assertEqual(response.code, 426)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_gen(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
|
|
@ -138,7 +267,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_websocket_callbacks(self):
|
||||
websocket_connect(
|
||||
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
|
||||
io_loop=self.io_loop, callback=self.stop)
|
||||
callback=self.stop)
|
||||
ws = self.wait().result()
|
||||
ws.write_message('hello')
|
||||
ws.read_message(self.stop)
|
||||
|
|
@ -159,9 +288,17 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
@gen_test
|
||||
def test_unicode_message(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
ws.write_message(u('hello \u00e9'))
|
||||
ws.write_message(u'hello \u00e9')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, u('hello \u00e9'))
|
||||
self.assertEqual(response, u'hello \u00e9')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_render_message(self):
|
||||
ws = yield self.ws_connect('/render')
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, '<b>hello</b>')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
|
|
@ -192,7 +329,6 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
with ExpectLog(gen_log, ".*"):
|
||||
yield websocket_connect(
|
||||
'ws://127.0.0.1:%d/' % port,
|
||||
io_loop=self.io_loop,
|
||||
connect_timeout=3600)
|
||||
|
||||
@gen_test
|
||||
|
|
@ -215,6 +351,18 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
self.assertEqual(response, 'hello')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_header_echo(self):
|
||||
# Ensure that headers can be returned in the response.
|
||||
# Specifically, that arbitrary headers passed through websocket_connect
|
||||
# can be returned.
|
||||
ws = yield websocket_connect(
|
||||
HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
|
||||
headers={'X-Test-Hello': 'hello'}))
|
||||
self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
|
||||
self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_server_close_reason(self):
|
||||
ws = yield self.ws_connect('/close_reason')
|
||||
|
|
@ -238,6 +386,14 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
self.assertEqual(code, 1001)
|
||||
self.assertEqual(reason, 'goodbye')
|
||||
|
||||
@gen_test
|
||||
def test_write_after_close(self):
|
||||
ws = yield self.ws_connect('/close_reason')
|
||||
msg = yield ws.read_message()
|
||||
self.assertIs(msg, None)
|
||||
with self.assertRaises(WebSocketClosedError):
|
||||
ws.write_message('hello')
|
||||
|
||||
@gen_test
|
||||
def test_async_prepare(self):
|
||||
# Previously, an async prepare method triggered a bug that would
|
||||
|
|
@ -247,6 +403,23 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello')
|
||||
|
||||
@gen_test
|
||||
def test_path_args(self):
|
||||
ws = yield self.ws_connect('/path_args/hello')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello')
|
||||
|
||||
@gen_test
|
||||
def test_coroutine(self):
|
||||
ws = yield self.ws_connect('/coroutine')
|
||||
# Send both messages immediately, coroutine must process one at a time.
|
||||
yield ws.write_message('hello1')
|
||||
yield ws.write_message('hello2')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello1')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello2')
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_valid_no_path(self):
|
||||
port = self.get_http_port()
|
||||
|
|
@ -254,8 +427,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
|
|
@ -268,8 +440,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
|
|
@ -283,8 +454,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
headers = {'Origin': '127.0.0.1:%d' % port}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
|
|
@ -297,8 +467,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
headers = {'Origin': 'http://somewhereelse.com'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
|
|
@ -312,21 +481,94 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
headers = {'Origin': 'http://subtenant.localhost'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_subprotocols(self):
|
||||
ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
|
||||
self.assertEqual(ws.selected_subprotocol, 'goodproto')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'subprotocol=goodproto')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_subprotocols_not_offered(self):
|
||||
ws = yield self.ws_connect('/subprotocol')
|
||||
self.assertIs(ws.selected_subprotocol, None)
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'subprotocol=None')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_open_coroutine(self):
|
||||
self.message_sent = Event()
|
||||
ws = yield self.ws_connect('/open_coroutine')
|
||||
yield ws.write_message('hello')
|
||||
self.message_sent.set()
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'ok')
|
||||
yield self.close(ws)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
|
||||
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
|
||||
def initialize(self, close_future, compression_options=None):
|
||||
super().initialize(close_future, compression_options)
|
||||
self.sleeping = 0
|
||||
|
||||
async def on_message(self, message):
|
||||
if self.sleeping > 0:
|
||||
self.write_message('another coroutine is already sleeping')
|
||||
self.sleeping += 1
|
||||
await gen.sleep(0.01)
|
||||
self.sleeping -= 1
|
||||
self.write_message(message)""")['NativeCoroutineOnMessageHandler']
|
||||
|
||||
|
||||
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/native', NativeCoroutineOnMessageHandler,
|
||||
dict(close_future=self.close_future))])
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_native_coroutine(self):
|
||||
ws = yield self.ws_connect('/native')
|
||||
# Send both messages immediately, coroutine must process one at a time.
|
||||
yield ws.write_message('hello1')
|
||||
yield ws.write_message('hello2')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello1')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello2')
|
||||
|
||||
|
||||
class CompressionTestMixin(object):
|
||||
MESSAGE = 'Hello world. Testing 123 123'
|
||||
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
|
||||
class LimitedHandler(TestWebSocketHandler):
|
||||
@property
|
||||
def max_message_size(self):
|
||||
return 1024
|
||||
|
||||
def on_message(self, message):
|
||||
self.write_message(str(len(message)))
|
||||
|
||||
return Application([
|
||||
('/echo', EchoHandler, dict(
|
||||
close_future=self.close_future,
|
||||
compression_options=self.get_server_compression_options())),
|
||||
('/limited', LimitedHandler, dict(
|
||||
close_future=self.close_future,
|
||||
compression_options=self.get_server_compression_options())),
|
||||
])
|
||||
|
||||
def get_server_compression_options(self):
|
||||
|
|
@ -352,6 +594,22 @@ class CompressionTestMixin(object):
|
|||
ws.protocol._wire_bytes_out)
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_size_limit(self):
|
||||
ws = yield self.ws_connect(
|
||||
'/limited',
|
||||
compression_options=self.get_client_compression_options())
|
||||
# Small messages pass through.
|
||||
ws.write_message('a' * 128)
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, '128')
|
||||
# This message is too big after decompression, but it compresses
|
||||
# down to a size that will pass the initial checks.
|
||||
ws.write_message('a' * 2048)
|
||||
response = yield ws.read_message()
|
||||
self.assertIsNone(response)
|
||||
yield self.close(ws)
|
||||
|
||||
|
||||
class UncompressedTestMixin(CompressionTestMixin):
|
||||
"""Specialization of CompressionTestMixin when we expect no compression."""
|
||||
|
|
@ -417,3 +675,101 @@ class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
|||
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
||||
def mask(self, mask, data):
|
||||
return speedups.websocket_mask(mask, data)
|
||||
|
||||
|
||||
class ServerPeriodicPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_pong(self, data):
|
||||
self.write_message("got pong")
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
], websocket_ping_interval=0.01)
|
||||
|
||||
@gen_test
|
||||
def test_server_ping(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
for i in range(3):
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, "got pong")
|
||||
yield self.close(ws)
|
||||
# TODO: test that the connection gets closed if ping responses stop.
|
||||
|
||||
|
||||
class ClientPeriodicPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_ping(self, data):
|
||||
self.write_message("got ping")
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
])
|
||||
|
||||
@gen_test
|
||||
def test_client_ping(self):
|
||||
ws = yield self.ws_connect('/', ping_interval=0.01)
|
||||
for i in range(3):
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, "got ping")
|
||||
yield self.close(ws)
|
||||
# TODO: test that the connection gets closed if ping responses stop.
|
||||
|
||||
|
||||
class ManualPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_ping(self, data):
|
||||
self.write_message(data, binary=isinstance(data, bytes))
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
])
|
||||
|
||||
@gen_test
|
||||
def test_manual_ping(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
|
||||
self.assertRaises(ValueError, ws.ping, 'a' * 126)
|
||||
|
||||
ws.ping('hello')
|
||||
resp = yield ws.read_message()
|
||||
# on_ping always sees bytes.
|
||||
self.assertEqual(resp, b'hello')
|
||||
|
||||
ws.ping(b'binary hello')
|
||||
resp = yield ws.read_message()
|
||||
self.assertEqual(resp, b'binary hello')
|
||||
yield self.close(ws)
|
||||
|
||||
|
||||
class MaxMessageSizeTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', EchoHandler, dict(close_future=self.close_future)),
|
||||
], websocket_max_message_size=1024)
|
||||
|
||||
@gen_test
|
||||
def test_large_message(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
|
||||
# Write a message that is allowed.
|
||||
msg = 'a' * 1024
|
||||
ws.write_message(msg)
|
||||
resp = yield ws.read_message()
|
||||
self.assertEqual(resp, msg)
|
||||
|
||||
# Write a message that is too large.
|
||||
ws.write_message(msg + 'b')
|
||||
resp = yield ws.read_message()
|
||||
# A message of None means the other side closed the connection.
|
||||
self.assertIs(resp, None)
|
||||
self.assertEqual(ws.close_code, 1009)
|
||||
self.assertEqual(ws.close_reason, "message too big")
|
||||
# TODO: Needs tests of messages split over multiple
|
||||
# continuation frames.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
import functools
|
||||
import os
|
||||
import socket
|
||||
import unittest
|
||||
|
||||
from tornado.platform.auto import set_close_exec
|
||||
|
||||
skipIfNonWindows = unittest.skipIf(os.name != 'nt', 'non-windows platform')
|
||||
|
||||
|
||||
@skipIfNonWindows
|
||||
class WindowsTest(unittest.TestCase):
|
||||
def test_set_close_exec(self):
|
||||
# set_close_exec works with sockets.
|
||||
s = socket.socket()
|
||||
self.addCleanup(s.close)
|
||||
set_close_exec(s.fileno())
|
||||
|
||||
# But it doesn't work with pipes.
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(functools.partial(os.close, r))
|
||||
self.addCleanup(functools.partial(os.close, w))
|
||||
with self.assertRaises(WindowsError):
|
||||
set_close_exec(r)
|
||||
|
|
@ -1,13 +1,16 @@
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from wsgiref.validate import validator
|
||||
|
||||
from tornado.escape import json_decode
|
||||
from tornado.test.httpserver_test import TypeCheckHandler
|
||||
from tornado.test.util import ignore_deprecation
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.util import u
|
||||
from tornado.web import RequestHandler, Application
|
||||
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
|
||||
|
||||
from tornado.test import httpserver_test
|
||||
from tornado.test import web_test
|
||||
|
||||
|
||||
class WSGIContainerTest(AsyncHTTPTestCase):
|
||||
def wsgi_app(self, environ, start_response):
|
||||
|
|
@ -24,7 +27,7 @@ class WSGIContainerTest(AsyncHTTPTestCase):
|
|||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
|
||||
class WSGIApplicationTest(AsyncHTTPTestCase):
|
||||
class WSGIAdapterTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class HelloHandler(RequestHandler):
|
||||
def get(self):
|
||||
|
|
@ -38,11 +41,13 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
|
|||
# another thread instead of using our own WSGIContainer, but this
|
||||
# fits better in our async testing framework and the wsgiref
|
||||
# validator should keep us honest
|
||||
return WSGIContainer(validator(WSGIApplication([
|
||||
("/", HelloHandler),
|
||||
("/path/(.*)", PathQuotingHandler),
|
||||
("/typecheck", TypeCheckHandler),
|
||||
])))
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(
|
||||
Application([
|
||||
("/", HelloHandler),
|
||||
("/path/(.*)", PathQuotingHandler),
|
||||
("/typecheck", TypeCheckHandler),
|
||||
]))))
|
||||
|
||||
def test_simple(self):
|
||||
response = self.fetch("/")
|
||||
|
|
@ -50,7 +55,7 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
|
|||
|
||||
def test_path_quoting(self):
|
||||
response = self.fetch("/path/foo%20bar%C3%A9")
|
||||
self.assertEqual(response.body, u("foo bar\u00e9").encode("utf-8"))
|
||||
self.assertEqual(response.body, u"foo bar\u00e9".encode("utf-8"))
|
||||
|
||||
def test_types(self):
|
||||
headers = {"Cookie": "foo=bar"}
|
||||
|
|
@ -62,39 +67,52 @@ class WSGIApplicationTest(AsyncHTTPTestCase):
|
|||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {})
|
||||
|
||||
# This is kind of hacky, but run some of the HTTPServer tests through
|
||||
# WSGIContainer and WSGIApplication to make sure everything survives
|
||||
# repeated disassembly and reassembly.
|
||||
from tornado.test import httpserver_test
|
||||
from tornado.test import web_test
|
||||
|
||||
|
||||
# This is kind of hacky, but run some of the HTTPServer and web tests
|
||||
# through WSGIContainer and WSGIApplication to make sure everything
|
||||
# survives repeated disassembly and reassembly.
|
||||
class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
|
||||
def get_app(self):
|
||||
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(Application(self.get_handlers()))))
|
||||
|
||||
|
||||
def wrap_web_tests_application():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
class WSGIApplicationWrappedTest(cls):
|
||||
def get_app(self):
|
||||
self.app = WSGIApplication(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(self.app))
|
||||
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
|
||||
def class_factory():
|
||||
class WSGIApplicationWrappedTest(cls): # type: ignore
|
||||
def setUp(self):
|
||||
self.warning_catcher = ignore_deprecation()
|
||||
self.warning_catcher.__enter__()
|
||||
super(WSGIApplicationWrappedTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(WSGIApplicationWrappedTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
def get_app(self):
|
||||
self.app = WSGIApplication(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(self.app))
|
||||
result["WSGIApplication_" + cls.__name__] = class_factory()
|
||||
return result
|
||||
|
||||
|
||||
globals().update(wrap_web_tests_application())
|
||||
|
||||
|
||||
def wrap_web_tests_adapter():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
class WSGIAdapterWrappedTest(cls):
|
||||
class WSGIAdapterWrappedTest(cls): # type: ignore
|
||||
def get_app(self):
|
||||
self.app = Application(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(WSGIAdapter(self.app)))
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(self.app)))
|
||||
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
|
||||
return result
|
||||
|
||||
|
||||
globals().update(wrap_web_tests_adapter())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue