update shared dependencies

This commit is contained in:
j 2016-02-23 11:36:55 +05:30
commit 736cd598a8
521 changed files with 45146 additions and 22574 deletions

View file

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.0"
version_info = (4, 0, 0, 0)
version = "4.3"
version_info = (4, 3, 0, 0)

View file

@ -0,0 +1,94 @@
#!/usr/bin/env python
# coding: utf-8
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Data used by the tornado.locale module."""
from __future__ import absolute_import, division, print_function, with_statement
# NOTE: This file is supposed to contain unicode strings, which is
# exactly what you'd get with e.g. u"Español" in most python versions.
# However, Python 3.2 doesn't support the u"" syntax, so we use a u()
# function instead. tornado.util.u cannot be used because it doesn't
# support non-ascii characters on python 2.
# When we drop support for Python 3.2, we can remove the parens
# and make these plain unicode strings.
from tornado.escape import to_unicode as u
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u("አማርኛ")},
"ar_AR": {"name_en": u("Arabic"), "name": u("العربية")},
"bg_BG": {"name_en": u("Bulgarian"), "name": u("Български")},
"bn_IN": {"name_en": u("Bengali"), "name": u("বাংলা")},
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
"ca_ES": {"name_en": u("Catalan"), "name": u("Català")},
"cs_CZ": {"name_en": u("Czech"), "name": u("Čeština")},
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
"el_GR": {"name_en": u("Greek"), "name": u("Ελληνικά")},
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Español (España)")},
"es_LA": {"name_en": u("Spanish"), "name": u("Español")},
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
"fa_IR": {"name_en": u("Persian"), "name": u("فارسی")},
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Français (Canada)")},
"fr_FR": {"name_en": u("French"), "name": u("Français")},
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
"he_IL": {"name_en": u("Hebrew"), "name": u("עברית")},
"hi_IN": {"name_en": u("Hindi"), "name": u("हिन्दी")},
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
"is_IS": {"name_en": u("Icelandic"), "name": u("Íslenska")},
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
"ja_JP": {"name_en": u("Japanese"), "name": u("日本語")},
"ko_KR": {"name_en": u("Korean"), "name": u("한국어")},
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvių")},
"lv_LV": {"name_en": u("Latvian"), "name": u("Latviešu")},
"mk_MK": {"name_en": u("Macedonian"), "name": u("Македонски")},
"ml_IN": {"name_en": u("Malayalam"), "name": u("മലയാളം")},
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokmål)")},
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
"pa_IN": {"name_en": u("Punjabi"), "name": u("ਪੰਜਾਬੀ")},
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Português (Brasil)")},
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Português (Portugal)")},
"ro_RO": {"name_en": u("Romanian"), "name": u("Română")},
"ru_RU": {"name_en": u("Russian"), "name": u("Русский")},
"sk_SK": {"name_en": u("Slovak"), "name": u("Slovenčina")},
"sl_SI": {"name_en": u("Slovenian"), "name": u("Slovenščina")},
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
"sr_RS": {"name_en": u("Serbian"), "name": u("Српски")},
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
"ta_IN": {"name_en": u("Tamil"), "name": u("தமிழ்")},
"te_IN": {"name_en": u("Telugu"), "name": u("తెలుగు")},
"th_TH": {"name_en": u("Thai"), "name": u("ภาษาไทย")},
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
"tr_TR": {"name_en": u("Turkish"), "name": u("Türkçe")},
"uk_UA": {"name_en": u("Ukraini "), "name": u("Українська")},
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Tiếng Việt")},
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("中文(简体)")},
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("中文(繁體)")},
}

View file

@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these
services implement authentication and authorization slightly differently.
See the individual service classes below for complete documentation.
Example usage for Google OpenID::
Example usage for Google OAuth:
.. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@ -51,6 +53,10 @@ Example usage for Google OpenID::
response_type='code',
extra_params={'approval_prompt': 'auto'})
.. testoutput::
:hide:
.. versionchanged:: 4.0
All of the callback interfaces in this module are now guaranteed
to run their callback with an argument of ``None`` on error.
@ -69,14 +75,14 @@ import hmac
import time
import uuid
from tornado.concurrent import TracebackFuture, chain_future, return_future
from tornado.concurrent import TracebackFuture, return_future, chain_future
from tornado import gen
from tornado import httpclient
from tornado import escape
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import bytes_type, u, unicode_type, ArgReplacer
from tornado.util import u, unicode_type, ArgReplacer
try:
import urlparse # py2
@ -123,6 +129,7 @@ def _auth_return_future(f):
if callback is not None:
future.add_done_callback(
functools.partial(_auth_future_to_callback, callback))
def handle_exception(typ, value, tb):
if future.done():
return False
@ -138,9 +145,6 @@ def _auth_return_future(f):
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
See `GoogleMixin` below for a customized example (which also
includes OAuth support).
Class attributes:
* ``_OPENID_ENDPOINT``: the identity provider's URI.
@ -312,8 +316,7 @@ class OpenIdMixin(object):
class OAuthMixin(object):
"""Abstract implementation of OAuth 1.0 and 1.0a.
See `TwitterMixin` and `FriendFeedMixin` below for example implementations,
or `GoogleMixin` for an OAuth/OpenID hybrid.
See `TwitterMixin` below for an example implementation.
Class attributes:
@ -333,7 +336,7 @@ class OAuthMixin(object):
The ``callback_uri`` may be omitted if you have previously
registered a callback URI with the third-party service. For
some sevices (including Friendfeed), you must use a
some services (including Friendfeed), you must use a
previously-registered callback URI and cannot specify a
callback via this method.
@ -565,7 +568,8 @@ class OAuthMixin(object):
class OAuth2Mixin(object):
"""Abstract implementation of OAuth 2.0.
See `FacebookGraphMixin` below for an example implementation.
See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example
implementations.
Class attributes:
@ -617,6 +621,72 @@ class OAuth2Mixin(object):
args.update(extra_params)
return url_concat(url, args)
@_auth_return_future
def oauth2_request(self, url, callback, access_token=None,
post_args=None, **args):
"""Fetches the given URL auth an OAuth2 access token.
If the request is a POST, ``post_args`` should be provided. Query
string arguments should be given as keyword arguments.
Example usage:
..testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@tornado.web.authenticated
@tornado.gen.coroutine
def get(self):
new_entry = yield self.oauth2_request(
"https://graph.facebook.com/me/feed",
post_args={"message": "I am posting from my Tornado application!"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
yield self.authorize_redirect()
return
self.finish("Posted a message!")
.. testoutput::
:hide:
.. versionadded:: 4.3
"""
all_args = {}
if access_token:
all_args["access_token"] = access_token
all_args.update(args)
if all_args:
url += "?" + urllib_parse.urlencode(all_args)
callback = functools.partial(self._on_oauth2_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
callback=callback)
else:
http.fetch(url, callback=callback)
def _on_oauth2_request(self, future, response):
if response.error:
future.set_exception(AuthError("Error response %s fetching %s" %
(response.error, response.request.url)))
return
future.set_result(escape.json_decode(response.body))
def get_auth_http_client(self):
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
the default.
.. versionadded:: 4.3
"""
return httpclient.AsyncHTTPClient()
class TwitterMixin(OAuthMixin):
"""Twitter OAuth authentication.
@ -629,7 +699,9 @@ class TwitterMixin(OAuthMixin):
URL you registered as your application's callback URL.
When your application is set up, you can use this mixin like this
to authenticate the user with Twitter and get access to their stream::
to authenticate the user with Twitter and get access to their stream:
.. testcode::
class TwitterLoginHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@ -641,6 +713,9 @@ class TwitterMixin(OAuthMixin):
else:
yield self.authorize_redirect()
.. testoutput::
:hide:
The user object returned by `~OAuthMixin.get_authenticated_user`
includes the attributes ``username``, ``name``, ``access_token``,
and all of the custom Twitter user attributes described at
@ -689,7 +764,9 @@ class TwitterMixin(OAuthMixin):
`~OAuthMixin.get_authenticated_user`. The user returned through that
process includes an 'access_token' attribute that can be used
to make authenticated requests via this method. Example
usage::
usage:
.. testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@ -706,6 +783,9 @@ class TwitterMixin(OAuthMixin):
return
self.finish("Posted a message!")
.. testoutput::
:hide:
"""
if path.startswith('http:') or path.startswith('https:'):
# Raw urls are useful for e.g. search which doesn't follow the
@ -757,223 +837,6 @@ class TwitterMixin(OAuthMixin):
raise gen.Return(user)
class FriendFeedMixin(OAuthMixin):
"""FriendFeed OAuth authentication.
To authenticate with FriendFeed, register your application with
FriendFeed at http://friendfeed.com/api/applications. Then copy
your Consumer Key and Consumer Secret to the application
`~tornado.web.Application.settings` ``friendfeed_consumer_key``
and ``friendfeed_consumer_secret``. Use this mixin on the handler
for the URL you registered as your application's Callback URL.
When your application is set up, you can use this mixin like this
to authenticate the user with FriendFeed and get access to their feed::
class FriendFeedLoginHandler(tornado.web.RequestHandler,
tornado.auth.FriendFeedMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
# Save the user using e.g. set_secure_cookie()
else:
yield self.authorize_redirect()
The user object returned by `~OAuthMixin.get_authenticated_user()` includes the
attributes ``username``, ``name``, and ``description`` in addition to
``access_token``. You should save the access token with the user;
it is required to make requests on behalf of the user later with
`friendfeed_request()`.
"""
_OAUTH_VERSION = "1.0"
_OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token"
_OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token"
_OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize"
_OAUTH_NO_CALLBACKS = True
_OAUTH_VERSION = "1.0"
@_auth_return_future
def friendfeed_request(self, path, callback, access_token=None,
post_args=None, **args):
"""Fetches the given relative API path, e.g., "/bret/friends"
If the request is a POST, ``post_args`` should be provided. Query
string arguments should be given as keyword arguments.
All the FriendFeed methods are documented at
http://friendfeed.com/api/documentation.
Many methods require an OAuth access token which you can
obtain through `~OAuthMixin.authorize_redirect` and
`~OAuthMixin.get_authenticated_user`. The user returned
through that process includes an ``access_token`` attribute that
can be used to make authenticated requests via this
method.
Example usage::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FriendFeedMixin):
@tornado.web.authenticated
@tornado.gen.coroutine
def get(self):
new_entry = yield self.friendfeed_request(
"/entry",
post_args={"body": "Testing Tornado Web Server"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
yield self.authorize_redirect()
return
self.finish("Posted a message!")
"""
# Add the OAuth resource request signature if we have credentials
url = "http://friendfeed-api.com/v2" + path
if access_token:
all_args = {}
all_args.update(args)
all_args.update(post_args or {})
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
args.update(oauth)
if args:
url += "?" + urllib_parse.urlencode(args)
callback = functools.partial(self._on_friendfeed_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
callback=callback)
else:
http.fetch(url, callback=callback)
def _on_friendfeed_request(self, future, response):
if response.error:
future.set_exception(AuthError(
"Error response %s fetching %s" % (response.error,
response.request.url)))
return
future.set_result(escape.json_decode(response.body))
def _oauth_consumer_token(self):
self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth")
self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth")
return dict(
key=self.settings["friendfeed_consumer_key"],
secret=self.settings["friendfeed_consumer_secret"])
@gen.coroutine
def _oauth_get_user_future(self, access_token, callback):
user = yield self.friendfeed_request(
"/feedinfo/" + access_token["username"],
include="id,name,description", access_token=access_token)
if user:
user["username"] = user["id"]
callback(user)
def _parse_user_response(self, callback, user):
if user:
user["username"] = user["id"]
callback(user)
class GoogleMixin(OpenIdMixin, OAuthMixin):
"""Google Open ID / OAuth authentication.
.. deprecated:: 4.0
New applications should use `GoogleOAuth2Mixin`
below instead of this class. As of May 19, 2014, Google has stopped
supporting registration-free authentication.
No application registration is necessary to use Google for
authentication or to access Google resources on behalf of a user.
Google implements both OpenID and OAuth in a hybrid mode. If you
just need the user's identity, use
`~OpenIdMixin.authenticate_redirect`. If you need to make
requests to Google on behalf of the user, use
`authorize_redirect`. On return, parse the response with
`~OpenIdMixin.get_authenticated_user`. We send a dict containing
the values for the user, including ``email``, ``name``, and
``locale``.
Example usage::
class GoogleLoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("openid.mode", None):
user = yield self.get_authenticated_user()
# Save the user with e.g. set_secure_cookie()
else:
yield self.authenticate_redirect()
"""
_OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud"
_OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken"
@return_future
def authorize_redirect(self, oauth_scope, callback_uri=None,
ax_attrs=["name", "email", "language", "username"],
callback=None):
"""Authenticates and authorizes for the given Google resource.
Some of the available resources which can be used in the ``oauth_scope``
argument are:
* Gmail Contacts - http://www.google.com/m8/feeds/
* Calendar - http://www.google.com/calendar/feeds/
* Finance - http://finance.google.com/finance/feeds/
You can authorize multiple resources by separating the resource
URLs with a space.
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
callback_uri = callback_uri or self.request.uri
args = self._openid_args(callback_uri, ax_attrs=ax_attrs,
oauth_scope=oauth_scope)
self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
callback()
@_auth_return_future
def get_authenticated_user(self, callback):
"""Fetches the authenticated user data upon redirect."""
# Look to see if we are doing combined OpenID/OAuth
oauth_ns = ""
for name, values in self.request.arguments.items():
if name.startswith("openid.ns.") and \
values[-1] == b"http://specs.openid.net/extensions/oauth/1.0":
oauth_ns = name[10:]
break
token = self.get_argument("openid." + oauth_ns + ".request_token", "")
if token:
http = self.get_auth_http_client()
token = dict(key=token, secret="")
http.fetch(self._oauth_access_token_url(token),
functools.partial(self._on_access_token, callback))
else:
chain_future(OpenIdMixin.get_authenticated_user(self),
callback)
def _oauth_consumer_token(self):
self.require_setting("google_consumer_key", "Google OAuth")
self.require_setting("google_consumer_secret", "Google OAuth")
return dict(
key=self.settings["google_consumer_key"],
secret=self.settings["google_consumer_secret"])
def _oauth_get_user_future(self, access_token):
return OpenIdMixin.get_authenticated_user(self)
class GoogleOAuth2Mixin(OAuth2Mixin):
"""Google authentication using OAuth2.
@ -994,24 +857,39 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
"""
_OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/auth"
_OAUTH_ACCESS_TOKEN_URL = "https://accounts.google.com/o/oauth2/token"
_OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
_OAUTH_NO_CALLBACKS = False
_OAUTH_SETTINGS_KEY = 'google_oauth'
@_auth_return_future
def get_authenticated_user(self, redirect_uri, code, callback):
"""Handles the login for the Google user, returning a user object.
"""Handles the login for the Google user, returning an access token.
Example usage::
The result is a dictionary containing an ``access_token`` field
([among others](https://developers.google.com/identity/protocols/OAuth2WebServer#handlingtheresponse)).
Unlike other ``get_authenticated_user`` methods in this package,
this method does not return any additional information about the user.
The returned access token can be used with `OAuth2Mixin.oauth2_request`
to request additional information (perhaps from
``https://www.googleapis.com/oauth2/v2/userinfo``)
Example usage:
.. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument('code', False):
user = yield self.get_authenticated_user(
access = yield self.get_authenticated_user(
redirect_uri='http://your.site.com/auth/google',
code=self.get_argument('code'))
# Save the user with e.g. set_secure_cookie
user = yield self.oauth2_request(
"https://www.googleapis.com/oauth2/v1/userinfo",
access_token=access["access_token"])
# Save the user and access token with
# e.g. set_secure_cookie.
else:
yield self.authorize_redirect(
redirect_uri='http://your.site.com/auth/google',
@ -1019,6 +897,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
scope=['profile', 'email'],
response_type='code',
extra_params={'approval_prompt': 'auto'})
.. testoutput::
:hide:
"""
http = self.get_auth_http_client()
body = urllib_parse.urlencode({
@ -1042,225 +924,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
args = escape.json_decode(response.body)
future.set_result(args)
def get_auth_http_client(self):
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
the default.
"""
return httpclient.AsyncHTTPClient()
class FacebookMixin(object):
"""Facebook Connect authentication.
.. deprecated:: 1.1
New applications should use `FacebookGraphMixin`
below instead of this class. This class does not support the
Future-based interface seen on other classes in this module.
To authenticate with Facebook, register your application with
Facebook at http://www.facebook.com/developers/apps.php. Then
copy your API Key and Application Secret to the application settings
``facebook_api_key`` and ``facebook_secret``.
When your application is set up, you can use this mixin like this
to authenticate the user with Facebook::
class FacebookHandler(tornado.web.RequestHandler,
tornado.auth.FacebookMixin):
@tornado.web.asynchronous
def get(self):
if self.get_argument("session", None):
self.get_authenticated_user(self._on_auth)
return
yield self.authenticate_redirect()
def _on_auth(self, user):
if not user:
raise tornado.web.HTTPError(500, "Facebook auth failed")
# Save the user using, e.g., set_secure_cookie()
The user object returned by `get_authenticated_user` includes the
attributes ``facebook_uid`` and ``name`` in addition to session attributes
like ``session_key``. You should save the session key with the user; it is
required to make requests on behalf of the user later with
`facebook_request`.
"""
@return_future
def authenticate_redirect(self, callback_uri=None, cancel_uri=None,
extended_permissions=None, callback=None):
"""Authenticates/installs this app for the current user.
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
self.require_setting("facebook_api_key", "Facebook Connect")
callback_uri = callback_uri or self.request.uri
args = {
"api_key": self.settings["facebook_api_key"],
"v": "1.0",
"fbconnect": "true",
"display": "page",
"next": urlparse.urljoin(self.request.full_url(), callback_uri),
"return_session": "true",
}
if cancel_uri:
args["cancel_url"] = urlparse.urljoin(
self.request.full_url(), cancel_uri)
if extended_permissions:
if isinstance(extended_permissions, (unicode_type, bytes_type)):
extended_permissions = [extended_permissions]
args["req_perms"] = ",".join(extended_permissions)
self.redirect("http://www.facebook.com/login.php?" +
urllib_parse.urlencode(args))
callback()
def authorize_redirect(self, extended_permissions, callback_uri=None,
cancel_uri=None, callback=None):
"""Redirects to an authorization request for the given FB resource.
The available resource names are listed at
http://wiki.developers.facebook.com/index.php/Extended_permission.
The most common resource types include:
* publish_stream
* read_stream
* email
* sms
extended_permissions can be a single permission name or a list of
names. To get the session secret and session key, call
get_authenticated_user() just as you would with
authenticate_redirect().
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
return self.authenticate_redirect(callback_uri, cancel_uri,
extended_permissions,
callback=callback)
def get_authenticated_user(self, callback):
"""Fetches the authenticated Facebook user.
The authenticated user includes the special Facebook attributes
'session_key' and 'facebook_uid' in addition to the standard
user attributes like 'name'.
"""
self.require_setting("facebook_api_key", "Facebook Connect")
session = escape.json_decode(self.get_argument("session"))
self.facebook_request(
method="facebook.users.getInfo",
callback=functools.partial(
self._on_get_user_info, callback, session),
session_key=session["session_key"],
uids=session["uid"],
fields="uid,first_name,last_name,name,locale,pic_square,"
"profile_url,username")
def facebook_request(self, method, callback, **args):
"""Makes a Facebook API REST request.
We automatically include the Facebook API key and signature, but
it is the callers responsibility to include 'session_key' and any
other required arguments to the method.
The available Facebook methods are documented here:
http://wiki.developers.facebook.com/index.php/API
Here is an example for the stream.get() method::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookMixin):
@tornado.web.authenticated
@tornado.web.asynchronous
def get(self):
self.facebook_request(
method="stream.get",
callback=self._on_stream,
session_key=self.current_user["session_key"])
def _on_stream(self, stream):
if stream is None:
# Not authorized to read the stream yet?
self.redirect(self.authorize_redirect("read_stream"))
return
self.render("stream.html", stream=stream)
"""
self.require_setting("facebook_api_key", "Facebook Connect")
self.require_setting("facebook_secret", "Facebook Connect")
if not method.startswith("facebook."):
method = "facebook." + method
args["api_key"] = self.settings["facebook_api_key"]
args["v"] = "1.0"
args["method"] = method
args["call_id"] = str(long(time.time() * 1e6))
args["format"] = "json"
args["sig"] = self._signature(args)
url = "http://api.facebook.com/restserver.php?" + \
urllib_parse.urlencode(args)
http = self.get_auth_http_client()
http.fetch(url, callback=functools.partial(
self._parse_response, callback))
def _on_get_user_info(self, callback, session, users):
if users is None:
callback(None)
return
callback({
"name": users[0]["name"],
"first_name": users[0]["first_name"],
"last_name": users[0]["last_name"],
"uid": users[0]["uid"],
"locale": users[0]["locale"],
"pic_square": users[0]["pic_square"],
"profile_url": users[0]["profile_url"],
"username": users[0].get("username"),
"session_key": session["session_key"],
"session_expires": session.get("expires"),
})
def _parse_response(self, callback, response):
if response.error:
gen_log.warning("HTTP error from Facebook: %s", response.error)
callback(None)
return
try:
json = escape.json_decode(response.body)
except Exception:
gen_log.warning("Invalid JSON from Facebook: %r", response.body)
callback(None)
return
if isinstance(json, dict) and json.get("error_code"):
gen_log.warning("Facebook error: %d: %r", json["error_code"],
json.get("error_msg"))
callback(None)
return
callback(json)
def _signature(self, args):
parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())]
body = "".join(parts) + self.settings["facebook_secret"]
if isinstance(body, unicode_type):
body = body.encode("utf-8")
return hashlib.md5(body).hexdigest()
def get_auth_http_client(self):
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
the default.
"""
return httpclient.AsyncHTTPClient()
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
@ -1274,9 +937,12 @@ class FacebookGraphMixin(OAuth2Mixin):
code, callback, extra_fields=None):
"""Handles the login for the Facebook user, returning a user object.
Example usage::
Example usage:
class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
.. testcode::
class FacebookGraphLoginHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("code", False):
@ -1291,6 +957,10 @@ class FacebookGraphMixin(OAuth2Mixin):
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
.. testoutput::
:hide:
"""
http = self.get_auth_http_client()
args = {
@ -1307,7 +977,7 @@ class FacebookGraphMixin(OAuth2Mixin):
http.fetch(self._oauth_request_token_url(**args),
functools.partial(self._on_access_token, redirect_uri, client_id,
client_secret, callback, fields))
client_secret, callback, fields))
def _on_access_token(self, redirect_uri, client_id, client_secret,
future, fields, response):
@ -1315,7 +985,7 @@ class FacebookGraphMixin(OAuth2Mixin):
future.set_exception(AuthError('Facebook auth error: %s' % str(response)))
return
args = escape.parse_qs_bytes(escape.native_str(response.body))
args = urlparse.parse_qs(escape.native_str(response.body))
session = {
"access_token": args["access_token"][-1],
"expires": args.get("expires")
@ -1358,7 +1028,9 @@ class FacebookGraphMixin(OAuth2Mixin):
process includes an ``access_token`` attribute that can be
used to make authenticated requests via this method.
Example usage::
Example usage:
..testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@ -1376,43 +1048,27 @@ class FacebookGraphMixin(OAuth2Mixin):
return
self.finish("Posted a message!")
.. testoutput::
:hide:
The given path is relative to ``self._FACEBOOK_BASE_URL``,
by default "https://graph.facebook.com".
This method is a wrapper around `OAuth2Mixin.oauth2_request`;
the only difference is that this method takes a relative path,
while ``oauth2_request`` takes a complete url.
.. versionchanged:: 3.1
Added the ability to override ``self._FACEBOOK_BASE_URL``.
"""
url = self._FACEBOOK_BASE_URL + path
all_args = {}
if access_token:
all_args["access_token"] = access_token
all_args.update(args)
if all_args:
url += "?" + urllib_parse.urlencode(all_args)
callback = functools.partial(self._on_facebook_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
callback=callback)
else:
http.fetch(url, callback=callback)
def _on_facebook_request(self, future, response):
if response.error:
future.set_exception(AuthError("Error response %s fetching %s" %
(response.error, response.request.url)))
return
future.set_result(escape.json_decode(response.body))
def get_auth_http_client(self):
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
the default.
"""
return httpclient.AsyncHTTPClient()
# Thanks to the _auth_return_future decorator, our "callback"
# argument is a Future, which we cannot pass as a callback to
# oauth2_request. Instead, have oauth2_request return a
# future and chain them together.
oauth_future = self.oauth2_request(url, access_token=access_token,
post_args=post_args, **args)
chain_future(oauth_future, callback)
def _oauth_signature(consumer_token, method, url, parameters={}, token=None):

View file

@ -100,6 +100,14 @@ try:
except ImportError:
signal = None
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
# This distinction is also important because when we use execv, we want to
# close the IOLoop and all its file descriptors, to guard against any
# file descriptors that were not set CLOEXEC. When execv is not available,
# we must not close the IOLoop because we want the process to exit cleanly.
_has_execv = sys.platform != 'win32'
_watched_files = set()
_reload_hooks = []
@ -108,14 +116,19 @@ _io_loops = weakref.WeakKeyDictionary()
def start(io_loop=None, check_time=500):
"""Begins watching source files for changes using the given `.IOLoop`. """
"""Begins watching source files for changes.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
io_loop = io_loop or ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
if _has_execv:
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
@ -162,7 +175,7 @@ def _reload_on_update(modify_times):
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in sys.modules.values():
for module in list(sys.modules.values()):
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
@ -211,10 +224,7 @@ def _reload():
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if sys.platform == 'win32':
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
if not _has_execv:
subprocess.Popen([sys.executable] + sys.argv)
sys.exit(0)
else:
@ -234,7 +244,10 @@ def _reload():
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable,
[sys.executable] + sys.argv)
sys.exit(0)
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
os._exit(0)
_USAGE = """\
Usage:
@ -276,11 +289,16 @@ def main():
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
# Execute the script in our namespace instead of creating
# a new one so that something that tries to import __main__
# (e.g. the unittest module) will see names defined in the
# script instead of just those defined in this module.
global __file__
__file__ = script
# Use globals as our "locals" dictionary so that
# something that tries to import __main__ (e.g. the unittest
# module) will see the right things.
# If __package__ is defined, imports may be incorrectly
# interpreted as relative to this module.
global __package__
del __package__
exec_in(f.read(), globals(), globals())
except SystemExit as e:
logging.basicConfig()

View file

@ -16,17 +16,20 @@
"""Utilities for working with threads and ``Futures``.
``Futures`` are a pattern for concurrent programming introduced in
Python 3.2 in the `concurrent.futures` package (this package has also
been backported to older versions of Python and can be installed with
``pip install futures``). Tornado will use `concurrent.futures.Future` if
it is available; otherwise it will use a compatible class defined in this
module.
Python 3.2 in the `concurrent.futures` package. This package defines
a mostly-compatible `Future` class designed for use from coroutines,
as well as some utility functions for interacting with the
`concurrent.futures` package.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import platform
import textwrap
import traceback
import sys
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
@ -36,9 +39,90 @@ except ImportError:
futures = None
# Can the garbage collector handle cycles that include __del__ methods?
# This is true in cpython beginning with version 3.4 (PEP 442).
_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
sys.version_info >= (3, 4))
class ReturnValueIgnoredError(Exception):
pass
# This class and associated code in the future object is derived
# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
class _TracebackLogger(object):
"""Helper to log a traceback upon destruction if not cleared.
This solves a nasty problem with Futures and Tasks that have an
exception set: if nobody asks for the exception, the exception is
never logged. This violates the Zen of Python: 'Errors should
never pass silently. Unless explicitly silenced.'
However, we don't want to log the exception as soon as
set_exception() is called: if the calling code is written
properly, it will get the exception and handle it properly. But
we *do* want to log it if result() or exception() was never called
-- otherwise developers waste a lot of time wondering why their
buggy code fails silently.
An earlier attempt added a __del__() method to the Future class
itself, but this backfired because the presence of __del__()
prevents garbage collection from breaking cycles. A way out of
this catch-22 is to avoid having a __del__() method on the Future
class itself, but instead to have a reference to a helper object
with a __del__() method that logs the traceback, where we ensure
that the helper object doesn't participate in cycles, and only the
Future has a reference to it.
The helper object is added when set_exception() is called. When
the Future is collected, and the helper is present, the helper
object is also collected, and its __del__() method will log the
traceback. When the Future's result() or exception() method is
called (and a helper object is present), it removes the the helper
object, after calling its clear() method to prevent it from
logging.
One downside is that we do a fair amount of work to extract the
traceback from the exception, even when it is never logged. It
would seem cheaper to just store the exception object, but that
references the traceback, which references stack frames, which may
reference the Future, which references the _TracebackLogger, and
then the _TracebackLogger would be included in a cycle, which is
what we're trying to avoid! As an optimization, we don't
immediately format the exception; we only do the work when
activate() is called, which call is delayed until after all the
Future's callbacks have run. Since usually a Future has at least
one callback (typically set by 'yield From') and usually that
callback extracts the callback, thereby removing the need to
format the exception.
PS. I don't claim credit for this solution. I first heard of it
in a discussion about closing files when they are collected.
"""
__slots__ = ('exc_info', 'formatted_tb')
def __init__(self, exc_info):
self.exc_info = exc_info
self.formatted_tb = None
def activate(self):
exc_info = self.exc_info
if exc_info is not None:
self.exc_info = None
self.formatted_tb = traceback.format_exception(*exc_info)
def clear(self):
self.exc_info = None
self.formatted_tb = None
def __del__(self):
if self.formatted_tb:
app_log.error('Future exception was never retrieved: %s',
''.join(self.formatted_tb).rstrip())
class Future(object):
"""Placeholder for an asynchronous result.
@ -67,14 +151,42 @@ class Future(object):
if that package was available and fall back to the thread-unsafe
implementation if it was not.
.. versionchanged:: 4.1
If a `.Future` contains an error but that error is never observed
(by calling ``result()``, ``exception()``, or ``exc_info()``),
a stack trace will be logged when the `.Future` is garbage collected.
This normally indicates an error in the application, but in cases
where it results in undesired logging it may be necessary to
suppress the logging by ensuring that the exception is observed:
``f.add_done_callback(lambda f: f.exception())``.
"""
def __init__(self):
self._done = False
self._result = None
self._exception = None
self._exc_info = None
self._log_traceback = False # Used for Python >= 3.4
self._tb_logger = None # Used for Python <= 3.3
self._callbacks = []
# Implement the Python 3.5 Awaitable protocol if possible
# (we can't use return and yield together until py33).
if sys.version_info >= (3, 3):
exec(textwrap.dedent("""
def __await__(self):
return (yield self)
"""))
else:
# Py2-compatible version for use with cython.
def __await__(self):
result = yield self
# StopIteration doesn't take args before py33,
# but Cython recognizes the args tuple.
e = StopIteration()
e.args = (result,)
raise e
def cancel(self):
"""Cancel the operation, if possible.
@ -99,25 +211,39 @@ class Future(object):
"""Returns True if the future has finished running."""
return self._done
def _clear_tb_log(self):
self._log_traceback = False
if self._tb_logger is not None:
self._tb_logger.clear()
self._tb_logger = None
def result(self, timeout=None):
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
This method takes a ``timeout`` argument for compatibility with
`concurrent.futures.Future` but it is an error to call it
before the `Future` is done, so the ``timeout`` is never used.
"""
self._clear_tb_log()
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
elif self._exception is not None:
raise self._exception
self._check_done()
return self._result
def exception(self, timeout=None):
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
This method takes a ``timeout`` argument for compatibility with
`concurrent.futures.Future` but it is an error to call it
before the `Future` is done, so the ``timeout`` is never used.
"""
if self._exception is not None:
return self._exception
self._clear_tb_log()
if self._exc_info is not None:
return self._exc_info[1]
else:
self._check_done()
return None
@ -146,14 +272,17 @@ class Future(object):
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
self._exception = exception
self._set_done()
self.set_exc_info(
(exception.__class__,
exception,
getattr(exception, '__traceback__', None)))
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
self._clear_tb_log()
return self._exc_info
def set_exc_info(self, exc_info):
@ -164,7 +293,18 @@ class Future(object):
.. versionadded:: 4.0
"""
self._exc_info = exc_info
self.set_exception(exc_info[1])
self._log_traceback = True
if not _GC_CYCLE_FINALIZERS:
self._tb_logger = _TracebackLogger(exc_info)
try:
self._set_done()
finally:
# Activate the logger after all callbacks have had a
# chance to call result() or exception().
if self._log_traceback and self._tb_logger is not None:
self._tb_logger.activate()
self._exc_info = exc_info
def _check_done(self):
if not self._done:
@ -173,10 +313,28 @@ class Future(object):
def _set_done(self):
self._done = True
for cb in self._callbacks:
# TODO: error handling
cb(self)
try:
cb(self)
except Exception:
app_log.exception('Exception in callback %r for %r',
cb, self)
self._callbacks = None
# On Python 3.3 or older, objects with a destructor part of a reference
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
# the PEP 442.
if _GC_CYCLE_FINALIZERS:
def __del__(self):
if not self._log_traceback:
# set_exception() was not called, or result() or exception()
# has consumed the exception
return
tb = traceback.format_exception(*self._exc_info)
app_log.error('Future %r exception was never retrieved: %s',
self, ''.join(tb).rstrip())
TracebackFuture = Future
if futures is None:
@ -204,24 +362,43 @@ class DummyExecutor(object):
dummy_executor = DummyExecutor()
def run_on_executor(fn):
def run_on_executor(*args, **kwargs):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
This decorator should be used only on methods of objects with attributes
``executor`` and ``io_loop``.
The `.IOLoop` and executor to be used are determined by the ``io_loop``
and ``executor`` attributes of ``self``. To use different attributes,
pass keyword arguments to the decorator::
@run_on_executor(executor='_thread_pool')
def foo(self):
pass
.. versionchanged:: 4.2
Added keyword arguments to use alternative attributes.
"""
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
future = self.executor.submit(fn, self, *args, **kwargs)
if callback:
self.io_loop.add_future(future,
lambda future: callback(future.result()))
return future
return wrapper
def run_on_executor_decorator(fn):
executor = kwargs.get("executor", "executor")
io_loop = kwargs.get("io_loop", "io_loop")
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
future = getattr(self, executor).submit(fn, self, *args, **kwargs)
if callback:
getattr(self, io_loop).add_future(
future, lambda future: callback(future.result()))
return future
return wrapper
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
return run_on_executor_decorator(args[0])
elif len(args) != 0:
raise ValueError("expected 1 argument, got %d", len(args))
return run_on_executor_decorator
_NO_RESULT = object()
@ -246,7 +423,9 @@ def return_future(f):
wait for the function to complete (perhaps by yielding it in a
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
Usage::
Usage:
.. testcode::
@return_future
def future_func(arg1, arg2, callback):
@ -258,6 +437,8 @@ def return_future(f):
yield future_func(arg1, arg2)
callback()
..
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
@ -289,7 +470,7 @@ def return_future(f):
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
raise_exc_info(exc_info)
future.result()
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens

View file

@ -19,24 +19,21 @@
from __future__ import absolute_import, division, print_function, with_statement
import collections
import functools
import logging
import pycurl
import threading
import time
from io import BytesIO
from tornado import httputil
from tornado import ioloop
from tornado.log import gen_log
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.util import bytes_type
try:
from io import BytesIO # py3
except ImportError:
from cStringIO import StringIO as BytesIO # py2
curl_log = logging.getLogger('tornado.curl_httpclient')
class CurlAsyncHTTPClient(AsyncHTTPClient):
@ -45,7 +42,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [_curl_create() for i in range(max_clients)]
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = collections.deque()
self._fds = {}
@ -211,9 +208,25 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback,
"curl_start_time": time.time(),
}
_curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"])
self._multi.add_handle(curl)
try:
self._curl_setup_request(
curl, request, curl.info["buffer"],
curl.info["headers"])
except Exception as e:
# If there was an error in setup, pass it on
# to the callback. Note that allowing the
# error to escape here will appear to work
# most of the time since we are still in the
# caller's original stack frame, but when
# _process_queue() is called from
# _finish_pending_requests the exceptions have
# nowhere to go.
callback(HTTPResponse(
request=request,
code=599,
error=e))
else:
self._multi.add_handle(curl)
if not started:
break
@ -259,6 +272,222 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
def _curl_create(self):
curl = pycurl.Curl()
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
return curl
def _curl_setup_request(self, curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
curl.setopt(pycurl.HTTPHEADER,
["%s: %s" % (native_str(k), native_str(v))
for k, v in request.headers.get_all()])
curl.setopt(pycurl.HEADERFUNCTION,
functools.partial(self._curl_header_callback,
headers, request.header_callback))
if request.streaming_callback:
def write_function(chunk):
self.io_loop.add_callback(request.streaming_callback, chunk)
else:
write_function = buffer.write
if bytes is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
body_expected = request.method in ("POST", "PATCH", "PUT")
body_present = request.body is not None
if not request.allow_nonstandard_methods:
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', request.method))
if body_expected or body_present:
if request.method == "GET":
# Even with `allow_nonstandard_methods` we disallow
# GET with a body (because libcurl doesn't allow it
# unless we use CUSTOMREQUEST). While the spec doesn't
# forbid clients from sending a body, it arguably
# disallows the server from doing anything with them.
raise ValueError('Body must be None for GET request')
request_buffer = BytesIO(utf8(request.body or ''))
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
else:
curl.setopt(pycurl.UPLOAD, True)
curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
curl_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
curl_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if request.ssl_options is not None:
raise ValueError("ssl_options not supported in curl_httpclient")
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(self, headers, header_callback, header_line):
header_line = native_str(header_line)
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
# whitespace at the start should be preserved to allow multi-line headers
header_line = header_line.rstrip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
curl_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
curl_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
def __init__(self, errno, message):
@ -266,212 +495,6 @@ class CurlError(HTTPError):
self.errno = errno
def _curl_create():
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, _curl_debug)
return curl
def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.items()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: request.header_callback(native_str(line)))
else:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: _curl_header_callback(headers,
native_str(line)))
if request.streaming_callback:
write_function = request.streaming_callback
else:
write_function = buffer.write
if bytes_type is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
if request.body is None:
raise AssertionError(
'Body must not be empty for "%s" request'
% request.method)
request_buffer = BytesIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.INFILESIZE, len(request.body))
elif request.method == "GET":
if request.body is not None:
raise AssertionError('Body must be empty for GET request')
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(headers, header_line):
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()

View file

@ -25,7 +25,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import re
import sys
from tornado.util import bytes_type, unicode_type, basestring_type, u
from tornado.util import unicode_type, basestring_type, u
try:
from urllib.parse import parse_qs as _parse_qs # py3
@ -82,7 +82,7 @@ def json_encode(value):
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javscript. Some json libraries do this escaping by default,
# the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
@ -187,7 +187,7 @@ else:
return encoded
_UTF8_TYPES = (bytes_type, type(None))
_UTF8_TYPES = (bytes, type(None))
def utf8(value):
@ -215,7 +215,7 @@ def to_unicode(value):
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes_type):
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
@ -246,7 +246,7 @@ def to_basestring(value):
"""
if isinstance(value, _BASESTRING_TYPES):
return value
if not isinstance(value, bytes_type):
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
@ -264,7 +264,7 @@ def recursive_unicode(obj):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes_type):
elif isinstance(obj, bytes):
return to_unicode(obj)
else:
return obj
@ -378,7 +378,10 @@ def linkify(text, shorten=False, extra_params="",
def _convert_entity(m):
if m.group(1) == "#":
try:
return unichr(int(m.group(2)))
if m.group(2)[:1].lower() == 'x':
return unichr(int(m.group(2)[1:], 16))
else:
return unichr(int(m.group(2)))
except ValueError:
return "&#%s;" % m.group(2)
try:

View file

@ -3,7 +3,9 @@ work in an asynchronous environment. Code using the ``gen`` module
is technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
For example, the following asynchronous handler::
For example, the following asynchronous handler:
.. testcode::
class AsyncHandler(RequestHandler):
@asynchronous
@ -16,7 +18,12 @@ For example, the following asynchronous handler::
do_something_with_response(response)
self.render("template.html")
could be written with ``gen`` as::
.. testoutput::
:hide:
could be written with ``gen`` as:
.. testcode::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
@ -26,12 +33,17 @@ could be written with ``gen`` as::
do_something_with_response(response)
self.render("template.html")
.. testoutput::
:hide:
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
You can also yield a list or dict of ``Futures``, which will be
started at the same time and run in parallel; a list or dict of results will
be returned when they are all finished::
be returned when they are all finished:
.. testcode::
@gen.coroutine
def get(self):
@ -43,20 +55,79 @@ be returned when they are all finished::
response3 = response_dict['response3']
response4 = response_dict['response4']
.. testoutput::
:hide:
If the `~functools.singledispatch` library is available (standard in
Python 3.4, available via the `singledispatch
<https://pypi.python.org/pypi/singledispatch>`_ package on older
versions), additional types of objects may be yielded. Tornado includes
support for ``asyncio.Future`` and Twisted's ``Deferred`` class when
``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported.
See the `convert_yielded` function to extend this mechanism.
.. versionchanged:: 3.2
Dict support added.
.. versionchanged:: 4.1
Support added for yielding ``asyncio`` Futures and Twisted Deferreds
via ``singledispatch``.
"""
from __future__ import absolute_import, division, print_function, with_statement
import collections
import functools
import itertools
import os
import sys
import textwrap
import types
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.util import raise_exc_info
try:
try:
from functools import singledispatch # py34+
except ImportError:
from singledispatch import singledispatch # backport
except ImportError:
# In most cases, singledispatch is required (to avoid
# difficult-to-diagnose problems in which the functionality
# available differs depending on which invisble packages are
# installed). However, in Google App Engine third-party
# dependencies are more trouble so we allow this module to be
# imported without it.
if 'APPENGINE_RUNTIME' not in os.environ:
raise
singledispatch = None
try:
try:
from collections.abc import Generator as GeneratorType # py35+
except ImportError:
from backports_abc import Generator as GeneratorType
try:
from inspect import isawaitable # py35+
except ImportError:
from backports_abc import isawaitable
except ImportError:
if 'APPENGINE_RUNTIME' not in os.environ:
raise
from types import GeneratorType
def isawaitable(x):
return False
try:
import builtins # py3
except ImportError:
import __builtin__ as builtins
class KeyReuseError(Exception):
@ -83,6 +154,21 @@ class TimeoutError(Exception):
"""Exception raised by ``with_timeout``."""
def _value_from_stopiteration(e):
try:
# StopIteration has a value attribute beginning in py33.
# So does our Return class.
return e.value
except AttributeError:
pass
try:
# Cython backports coroutine functionality by putting the value in
# e.args[0].
return e.args[0]
except (AttributeError, IndexError):
return None
def engine(func):
"""Callback-oriented decorator for asynchronous generators.
@ -101,15 +187,20 @@ def engine(func):
which use ``self.finish()`` in place of a callback argument.
"""
func = _make_coroutine_wrapper(func, replace_callback=False)
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = func(*args, **kwargs)
def final_callback(future):
if future.result() is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(future.result(),))
future.add_done_callback(final_callback)
# The engine interface doesn't give us any way to return
# errors but to raise them into the stack context.
# Save the stack context here to use when the Future has resolved.
future.add_done_callback(stack_context.wrap(final_callback))
return wrapper
@ -136,6 +227,17 @@ def coroutine(func, replace_callback=True):
From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``.
.. warning::
When exceptions occur inside a coroutine, the exception
information will be stored in the `.Future` object. You must
examine the result of the `.Future` object, or the exception
may go unnoticed by your code. This means yielding the function
if called from another coroutine, using something like
`.IOLoop.run_sync` for top-level calls, or passing the `.Future`
to `.IOLoop.add_future`.
"""
return _make_coroutine_wrapper(func, replace_callback=True)
@ -147,6 +249,11 @@ def _make_coroutine_wrapper(func, replace_callback):
argument, so we cannot simply implement ``@engine`` in terms of
``@coroutine``.
"""
# On Python 3.5, set the coroutine flag on our generator, to allow it
# to be used with 'await'.
if hasattr(types, 'coroutine'):
func = types.coroutine(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = TracebackFuture()
@ -159,12 +266,12 @@ def _make_coroutine_wrapper(func, replace_callback):
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
result = _value_from_stopiteration(e)
except Exception:
future.set_exc_info(sys.exc_info())
return future
else:
if isinstance(result, types.GeneratorType):
if isinstance(result, GeneratorType):
# Inline the first iteration of Runner.run. This lets us
# avoid the cost of creating a Runner when the coroutine
# never actually yields, which in turn allows us to
@ -180,12 +287,23 @@ def _make_coroutine_wrapper(func, replace_callback):
'stack_context inconsistency (probably caused '
'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
future.set_result(getattr(e, 'value', None))
future.set_result(_value_from_stopiteration(e))
except Exception:
future.set_exc_info(sys.exc_info())
else:
Runner(result, future, yielded)
return future
try:
return future
finally:
# Subtle memory optimization: if next() raised an exception,
# the future's exc_info contains a traceback which
# includes this stack frame. This creates a cycle,
# which will be collected at the next full GC but has
# been shown to greatly increase memory usage of
# benchmarks (relative to the refcount-based scheme
# used in the absence of cycles). We can avoid the
# cycle by clearing the local variable after we return it.
future = None
future.set_result(result)
return future
return wrapper
@ -214,6 +332,127 @@ class Return(Exception):
def __init__(self, value=None):
super(Return, self).__init__()
self.value = value
# Cython recognizes subclasses of StopIteration with a .args tuple.
self.args = (value,)
class WaitIterator(object):
"""Provides an iterator to yield the results of futures as they finish.
Yielding a set of futures like this:
``results = yield [future1, future2]``
pauses the coroutine until both ``future1`` and ``future2``
return, and then restarts the coroutine with the results of both
futures. If either future is an exception, the expression will
raise that exception and all the results will be lost.
If you need to get the result of each future as soon as possible,
or if you need the result of some futures even if others produce
errors, you can use ``WaitIterator``::
wait_iterator = gen.WaitIterator(future1, future2)
while not wait_iterator.done():
try:
result = yield wait_iterator.next()
except Exception as e:
print("Error {} from {}".format(e, wait_iterator.current_future))
else:
print("Result {} received from {} at {}".format(
result, wait_iterator.current_future,
wait_iterator.current_index))
Because results are returned as soon as they are available the
output from the iterator *will not be in the same order as the
input arguments*. If you need to know which future produced the
current result, you can use the attributes
``WaitIterator.current_future``, or ``WaitIterator.current_index``
to get the index of the future from the input list. (if keyword
arguments were used in the construction of the `WaitIterator`,
``current_index`` will use the corresponding keyword).
On Python 3.5, `WaitIterator` implements the async iterator
protocol, so it can be used with the ``async for`` statement (note
that in this version the entire iteration is aborted if any value
raises an exception, while the previous example can continue past
individual errors)::
async for result in gen.WaitIterator(future1, future2):
print("Result {} received from {} at {}".format(
result, wait_iterator.current_future,
wait_iterator.current_index))
.. versionadded:: 4.1
.. versionchanged:: 4.3
Added ``async for`` support in Python 3.5.
"""
def __init__(self, *args, **kwargs):
if args and kwargs:
raise ValueError(
"You must provide args or kwargs, not both")
if kwargs:
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
futures = list(kwargs.values())
else:
self._unfinished = dict((f, i) for (i, f) in enumerate(args))
futures = args
self._finished = collections.deque()
self.current_index = self.current_future = None
self._running_future = None
for future in futures:
future.add_done_callback(self._done_callback)
def done(self):
"""Returns True if this iterator has no more results."""
if self._finished or self._unfinished:
return False
# Clear the 'current' values when iteration is done.
self.current_index = self.current_future = None
return True
def next(self):
"""Returns a `.Future` that will yield the next available result.
Note that this `.Future` will not be the same object as any of
the inputs.
"""
self._running_future = TracebackFuture()
if self._finished:
self._return_result(self._finished.popleft())
return self._running_future
def _done_callback(self, done):
if self._running_future and not self._running_future.done():
self._return_result(done)
else:
self._finished.append(done)
def _return_result(self, done):
"""Called set the returned future's state that of the future
we yielded, and set the current future for the iterator.
"""
chain_future(done, self._running_future)
self.current_future = done
self.current_index = self._unfinished.pop(done)
@coroutine
def __aiter__(self):
raise Return(self)
def __anext__(self):
if self.done():
# Lookup by name to silence pyflakes on older versions.
raise getattr(builtins, 'StopAsyncIteration')()
return self.next()
class YieldPoint(object):
@ -330,11 +569,13 @@ def Task(func, *args, **kwargs):
yielded.
"""
future = Future()
def handle_exception(typ, value, tb):
if future.done():
return False
future.set_exc_info((typ, value, tb))
return True
def set_result(result):
if future.done():
return
@ -346,6 +587,11 @@ def Task(func, *args, **kwargs):
class YieldFuture(YieldPoint):
def __init__(self, future, io_loop=None):
"""Adapts a `.Future` to the `YieldPoint` interface.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
self.future = future
self.io_loop = io_loop or IOLoop.current()
@ -357,7 +603,7 @@ class YieldFuture(YieldPoint):
self.io_loop.add_future(self.future, runner.result_callback(self.key))
else:
self.runner = None
self.result = self.future.result()
self.result_fn = self.future.result
def is_ready(self):
if self.runner is not None:
@ -369,33 +615,110 @@ class YieldFuture(YieldPoint):
if self.runner is not None:
return self.runner.pop_result(self.key).result()
else:
return self.result
return self.result_fn()
class Multi(YieldPoint):
def _contains_yieldpoint(children):
"""Returns True if ``children`` contains any YieldPoints.
``children`` may be a dict or a list, as used by `MultiYieldPoint`
and `multi_future`.
"""
if isinstance(children, dict):
return any(isinstance(i, YieldPoint) for i in children.values())
if isinstance(children, list):
return any(isinstance(i, YieldPoint) for i in children)
return False
def multi(children, quiet_exceptions=()):
"""Runs multiple asynchronous operations in parallel.
Takes a list of ``YieldPoints`` or ``Futures`` and returns a list of
their responses. It is not necessary to call `Multi` explicitly,
since the engine will do so automatically when the generator yields
a list of ``YieldPoints`` or a mixture of ``YieldPoints`` and ``Futures``.
``children`` may either be a list or a dict whose values are
yieldable objects. ``multi()`` returns a new yieldable
object that resolves to a parallel structure containing their
results. If ``children`` is a list, the result is a list of
results in the same order; if it is a dict, the result is a dict
with the same keys.
That is, ``results = yield multi(list_of_futures)`` is equivalent
to::
results = []
for future in list_of_futures:
results.append(yield future)
If any children raise exceptions, ``multi()`` will raise the first
one. All others will be logged, unless they are of types
contained in the ``quiet_exceptions`` argument.
If any of the inputs are `YieldPoints <YieldPoint>`, the returned
yieldable object is a `YieldPoint`. Otherwise, returns a `.Future`.
This means that the result of `multi` can be used in a native
coroutine if and only if all of its children can be.
In a ``yield``-based coroutine, it is not normally necessary to
call this function directly, since the coroutine runner will
do it automatically when a list or dict is yielded. However,
it is necessary in ``await``-based coroutines, or to pass
the ``quiet_exceptions`` argument.
This function is available under the names ``multi()`` and ``Multi()``
for historical reasons.
.. versionchanged:: 4.2
If multiple yieldables fail, any exceptions after the first
(which is raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
.. versionchanged:: 4.3
Replaced the class ``Multi`` and the function ``multi_future``
with a unified function ``multi``. Added support for yieldables
other than `YieldPoint` and `.Future`.
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
"""
def __init__(self, children):
if _contains_yieldpoint(children):
return MultiYieldPoint(children, quiet_exceptions=quiet_exceptions)
else:
return multi_future(children, quiet_exceptions=quiet_exceptions)
Multi = multi
class MultiYieldPoint(YieldPoint):
"""Runs multiple asynchronous operations in parallel.
This class is similar to `multi`, but it always creates a stack
context even when no children require it. It is not compatible with
native coroutines.
.. versionchanged:: 4.2
If multiple ``YieldPoints`` fail, any exceptions after the first
(which is raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
.. versionchanged:: 4.3
Renamed from ``Multi`` to ``MultiYieldPoint``. The name ``Multi``
remains as an alias for the equivalent `multi` function.
.. deprecated:: 4.3
Use `multi` instead.
"""
def __init__(self, children, quiet_exceptions=()):
self.keys = None
if isinstance(children, dict):
self.keys = list(children.keys())
children = children.values()
self.children = []
for i in children:
if not isinstance(i, YieldPoint):
i = convert_yielded(i)
if is_future(i):
i = YieldFuture(i)
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
self.unfinished_children = set(self.children)
self.quiet_exceptions = quiet_exceptions
def start(self, runner):
for i in self.children:
@ -408,58 +731,80 @@ class Multi(YieldPoint):
return not self.unfinished_children
def get_result(self):
result = (i.get_result() for i in self.children)
result_list = []
exc_info = None
for f in self.children:
try:
result_list.append(f.get_result())
except Exception as e:
if exc_info is None:
exc_info = sys.exc_info()
else:
if not isinstance(e, self.quiet_exceptions):
app_log.error("Multiple exceptions in yield list",
exc_info=True)
if exc_info is not None:
raise_exc_info(exc_info)
if self.keys is not None:
return dict(zip(self.keys, result))
return dict(zip(self.keys, result_list))
else:
return list(result)
return list(result_list)
def multi_future(children):
def multi_future(children, quiet_exceptions=()):
"""Wait for multiple asynchronous futures in parallel.
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
a new Future that resolves when all the other Futures are done.
If all the ``Futures`` succeeded, the returned Future's result is a list
of their results. If any failed, the returned Future raises the exception
of the first one to fail.
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
It is not necessary to call `multi_future` explcitly, since the engine will
do so automatically when the generator yields a list of `Futures`.
This function is faster than the `Multi` `YieldPoint` because it does not
require the creation of a stack context.
This function is similar to `multi`, but does not support
`YieldPoints <YieldPoint>`.
.. versionadded:: 4.0
.. versionchanged:: 4.2
If multiple ``Futures`` fail, any exceptions after the first (which is
raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
.. deprecated:: 4.3
Use `multi` instead.
"""
if isinstance(children, dict):
keys = list(children.keys())
children = children.values()
else:
keys = None
children = list(map(convert_yielded, children))
assert all(is_future(i) for i in children)
unfinished_children = set(children)
future = Future()
if not children:
future.set_result({} if keys is not None else [])
def callback(f):
unfinished_children.remove(f)
if not unfinished_children:
try:
result_list = [i.result() for i in children]
except Exception:
future.set_exc_info(sys.exc_info())
else:
result_list = []
for f in children:
try:
result_list.append(f.result())
except Exception as e:
if future.done():
if not isinstance(e, quiet_exceptions):
app_log.error("Multiple exceptions in yield list",
exc_info=True)
else:
future.set_exc_info(sys.exc_info())
if not future.done():
if keys is not None:
future.set_result(dict(zip(keys, result_list)))
else:
future.set_result(result_list)
listening = set()
for f in children:
f.add_done_callback(callback)
if f not in listening:
listening.add(f)
f.add_done_callback(callback)
return future
@ -470,6 +815,11 @@ def maybe_future(x):
it is wrapped in a new `.Future`. This is suitable for use as
``result = yield gen.maybe_future(f())`` when you don't know whether
``f()`` returns a `.Future` or not.
.. deprecated:: 4.3
This function only handles ``Futures``, not other yieldable objects.
Instead of `maybe_future`, check for the non-future result types
you expect (often just ``None``), and ``yield`` anything unknown.
"""
if is_future(x):
return x
@ -479,7 +829,7 @@ def maybe_future(x):
return fut
def with_timeout(timeout, future, io_loop=None):
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
"""Wraps a `.Future` in a timeout.
Raises `TimeoutError` if the input future does not complete before
@ -487,9 +837,17 @@ def with_timeout(timeout, future, io_loop=None):
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
If the wrapped `.Future` fails after it has timed out, the exception
will be logged unless it is of a type contained in ``quiet_exceptions``
(which may be an exception type or a sequence of types).
Currently only supports Futures, not other `YieldPoint` classes.
.. versionadded:: 4.0
.. versionchanged:: 4.1
Added the ``quiet_exceptions`` argument and the logging of unhandled
exceptions.
"""
# TODO: allow yield points in addition to futures?
# Tricky to do with stack_context semantics.
@ -503,9 +861,21 @@ def with_timeout(timeout, future, io_loop=None):
chain_future(future, result)
if io_loop is None:
io_loop = IOLoop.current()
def error_callback(future):
try:
future.result()
except Exception as e:
if not isinstance(e, quiet_exceptions):
app_log.error("Exception in Future %r after timeout",
future, exc_info=True)
def timeout_callback():
result.set_exception(TimeoutError("Timeout"))
# In case the wrapped future goes on to fail, log it.
future.add_done_callback(error_callback)
timeout_handle = io_loop.add_timeout(
timeout,
lambda: result.set_exception(TimeoutError("Timeout")))
timeout, timeout_callback)
if isinstance(future, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
@ -520,6 +890,25 @@ def with_timeout(timeout, future, io_loop=None):
return result
def sleep(duration):
"""Return a `.Future` that resolves after the given number of seconds.
When used with ``yield`` in a coroutine, this is a non-blocking
analogue to `time.sleep` (which should not be used in coroutines
because it is blocking)::
yield gen.sleep(0.5)
Note that calling this function on its own does nothing; you must
wait on the `.Future` it returns (usually by yielding it).
.. versionadded:: 4.1
"""
f = Future()
IOLoop.current().call_later(duration, lambda: f.set_result(None))
return f
_null_future = Future()
_null_future.set_result(None)
@ -613,13 +1002,20 @@ class Runner(object):
self.future = None
try:
orig_stack_contexts = stack_context._state.contexts
exc_info = None
try:
value = future.result()
except Exception:
self.had_exception = True
yielded = self.gen.throw(*sys.exc_info())
exc_info = sys.exc_info()
if exc_info is not None:
yielded = self.gen.throw(*exc_info)
exc_info = None
else:
yielded = self.gen.send(value)
if stack_context._state.contexts is not orig_stack_contexts:
self.gen.throw(
stack_context.StackContextInconsistentError(
@ -636,7 +1032,7 @@ class Runner(object):
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.result_future.set_result(getattr(e, 'value', None))
self.result_future.set_result(_value_from_stopiteration(e))
self.result_future = None
self._deactivate_stack_context()
return
@ -653,19 +1049,16 @@ class Runner(object):
self.running = False
def handle_yield(self, yielded):
if isinstance(yielded, list):
if all(is_future(f) for f in yielded):
yielded = multi_future(yielded)
else:
yielded = Multi(yielded)
elif isinstance(yielded, dict):
if all(is_future(f) for f in yielded.values()):
yielded = multi_future(yielded)
else:
yielded = Multi(yielded)
# Lists containing YieldPoints require stack contexts;
# other lists are handled in convert_yielded.
if _contains_yieldpoint(yielded):
yielded = multi(yielded)
if isinstance(yielded, YieldPoint):
# YieldPoints are too closely coupled to the Runner to go
# through the generic convert_yielded mechanism.
self.future = TracebackFuture()
def start_yield_point():
try:
yielded.start(self)
@ -677,12 +1070,14 @@ class Runner(object):
except Exception:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if self.stack_context_deactivate is None:
# Start a stack context if this is the first
# YieldPoint we've seen.
with stack_context.ExceptionStackContext(
self.handle_exception) as deactivate:
self.stack_context_deactivate = deactivate
def cb():
start_yield_point()
self.run()
@ -690,16 +1085,17 @@ class Runner(object):
return False
else:
start_yield_point()
elif is_future(yielded):
self.future = yielded
if not self.future.done() or self.future is moment:
self.io_loop.add_future(
self.future, lambda f: self.run())
return False
else:
self.future = TracebackFuture()
self.future.set_exception(BadYieldError(
"yielded unknown object %r" % (yielded,)))
try:
self.future = convert_yielded(yielded)
except BadYieldError:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if not self.future.done() or self.future is moment:
self.io_loop.add_future(
self.future, lambda f: self.run())
return False
return True
def result_callback(self, key):
@ -738,3 +1134,108 @@ def _argument_adapter(callback):
else:
callback(None)
return wrapper
# Convert Awaitables into Futures. It is unfortunately possible
# to have infinite recursion here if those Awaitables assume that
# we're using a different coroutine runner and yield objects
# we don't understand. If that happens, the solution is to
# register that runner's yieldable objects with convert_yielded.
if sys.version_info >= (3, 3):
exec(textwrap.dedent("""
@coroutine
def _wrap_awaitable(x):
if hasattr(x, '__await__'):
x = x.__await__()
return (yield from x)
"""))
else:
# Py2-compatible version for use with Cython.
# Copied from PEP 380.
@coroutine
def _wrap_awaitable(x):
if hasattr(x, '__await__'):
_i = x.__await__()
else:
_i = iter(x)
try:
_y = next(_i)
except StopIteration as _e:
_r = _value_from_stopiteration(_e)
else:
while 1:
try:
_s = yield _y
except GeneratorExit as _e:
try:
_m = _i.close
except AttributeError:
pass
else:
_m()
raise _e
except BaseException as _e:
_x = sys.exc_info()
try:
_m = _i.throw
except AttributeError:
raise _e
else:
try:
_y = _m(*_x)
except StopIteration as _e:
_r = _value_from_stopiteration(_e)
break
else:
try:
if _s is None:
_y = next(_i)
else:
_y = _i.send(_s)
except StopIteration as _e:
_r = _value_from_stopiteration(_e)
break
raise Return(_r)
def convert_yielded(yielded):
"""Convert a yielded object into a `.Future`.
The default implementation accepts lists, dictionaries, and Futures.
If the `~functools.singledispatch` library is available, this function
may be extended to support additional types. For example::
@convert_yielded.register(asyncio.Future)
def _(asyncio_future):
return tornado.platform.asyncio.to_tornado_future(asyncio_future)
.. versionadded:: 4.1
"""
# Lists and dicts containing YieldPoints were handled earlier.
if isinstance(yielded, (list, dict)):
return multi(yielded)
elif is_future(yielded):
return yielded
elif isawaitable(yielded):
return _wrap_awaitable(yielded)
else:
raise BadYieldError("yielded unknown object %r" % (yielded,))
if singledispatch is not None:
convert_yielded = singledispatch(convert_yielded)
try:
# If we can import t.p.asyncio, do it for its side effect
# (registering asyncio.Future with convert_yielded).
# It's ugly to do this here, but it prevents a cryptic
# infinite recursion in _wrap_awaitable.
# Note that even with this, asyncio integration is unlikely
# to work unless the application also configures AsyncIOLoop,
# but at least the error messages in that case are more
# comprehensible than a stack overflow.
import tornado.platform.asyncio
except ImportError:
pass
else:
# Reference the imported module to make pyflakes happy.
tornado

View file

@ -21,6 +21,8 @@
from __future__ import absolute_import, division, print_function, with_statement
import re
from tornado.concurrent import Future
from tornado.escape import native_str, utf8
from tornado import gen
@ -35,6 +37,7 @@ class _QuietException(Exception):
def __init__(self):
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
@ -51,6 +54,7 @@ class _ExceptionLoggingContext(object):
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
@ -160,7 +164,8 @@ class HTTP1Connection(httputil.HTTPConnection):
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
io_loop=self.stream.io_loop)
io_loop=self.stream.io_loop,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
@ -191,8 +196,17 @@ class HTTP1Connection(httputil.HTTPConnection):
skip_body = True
code = start_line.code
if code == 304:
# 304 responses may include the content-length header
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True
if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
if ('Content-Length' in headers or
'Transfer-Encoding' in headers):
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code)
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
@ -201,7 +215,8 @@ class HTTP1Connection(httputil.HTTPConnection):
not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(headers, delegate)
body_future = self._read_body(
start_line.code if self.is_client else 0, headers, delegate)
if body_future is not None:
if self._body_timeout is None:
yield body_future
@ -209,7 +224,8 @@ class HTTP1Connection(httputil.HTTPConnection):
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future, self.stream.io_loop)
body_future, self.stream.io_loop,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
@ -294,6 +310,8 @@ class HTTP1Connection(httputil.HTTPConnection):
self._clear_callbacks()
stream = self.stream
self.stream = None
if not self._finish_future.done():
self._finish_future.set_result(None)
return stream
def set_body_timeout(self, timeout):
@ -312,8 +330,10 @@ class HTTP1Connection(httputil.HTTPConnection):
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
self._request_start_line = start_line
lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
@ -322,6 +342,7 @@ class HTTP1Connection(httputil.HTTPConnection):
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
@ -351,7 +372,6 @@ class HTTP1Connection(httputil.HTTPConnection):
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
lines = [utf8("%s %s %s" % start_line)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
for line in lines:
if b'\n' in line:
@ -360,6 +380,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@ -398,6 +419,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
self._write_future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@ -437,6 +459,9 @@ class HTTP1Connection(httputil.HTTPConnection):
self._pending_write.add_done_callback(self._finish_request)
def _on_write_complete(self, future):
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
@ -455,6 +480,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers
or headers.get("Transfer-Encoding", "").lower() == "chunked"
or start_line.method in ("HEAD", "GET")):
return connection_header == "keep-alive"
return False
@ -471,9 +497,14 @@ class HTTP1Connection(httputil.HTTPConnection):
self._finish_future.set_result(None)
def _parse_headers(self, data):
data = native_str(data.decode('latin1'))
eol = data.find("\r\n")
start_line = data[:eol]
# The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
# RFC 7230 section allows for both CRLF and bare LF.
eol = data.find("\n")
start_line = data[:eol].rstrip("\r")
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
@ -482,12 +513,42 @@ class HTTP1Connection(httputil.HTTPConnection):
data[eol:100])
return start_line, headers
def _read_body(self, headers, delegate):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
def _read_body(self, code, headers, delegate):
if "Content-Length" in headers:
if "Transfer-Encoding" in headers:
# Response cannot contain both Content-Length and
# Transfer-Encoding headers.
# http://tools.ietf.org/html/rfc7230#section-3.3.3
raise httputil.HTTPInputError(
"Response with both Transfer-Encoding and Content-Length")
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
"Multiple unequal Content-Lengths: %r" %
headers["Content-Length"])
headers["Content-Length"] = pieces[0]
content_length = int(headers["Content-Length"])
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
if code == 204:
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in headers or
content_length not in (None, 0)):
raise httputil.HTTPInputError(
"Response with code %d should not have body" % code)
content_length = 0
if content_length is not None:
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body(delegate)
@ -503,7 +564,9 @@ class HTTP1Connection(httputil.HTTPConnection):
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(body))
ret = delegate.data_received(body)
if ret is not None:
yield ret
@gen.coroutine
def _read_chunked_body(self, delegate):
@ -524,7 +587,9 @@ class HTTP1Connection(httputil.HTTPConnection):
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(chunk))
ret = delegate.data_received(chunk)
if ret is not None:
yield ret
# chunk ends with \r\n
crlf = yield self.stream.read_bytes(2)
assert crlf == b"\r\n"
@ -564,11 +629,14 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
decompressed = self._decompressor.decompress(
compressed_data, self._chunk_size)
if decompressed:
yield gen.maybe_future(
self._delegate.data_received(decompressed))
ret = self._delegate.data_received(decompressed)
if ret is not None:
yield ret
compressed_data = self._decompressor.unconsumed_tail
else:
yield gen.maybe_future(self._delegate.data_received(chunk))
ret = self._delegate.data_received(chunk)
if ret is not None:
yield ret
def finish(self):
if self._decompressor is not None:
@ -582,6 +650,9 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
self._delegate.data_received(tail)
return self._delegate.finish()
def on_connection_close(self):
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""

View file

@ -63,11 +63,16 @@ class HTTPClient(object):
response = http_client.fetch("http://www.google.com/")
print response.body
except httpclient.HTTPError as e:
print "Error:", e
# HTTPError is raised for non-200 responses; the response
# can be found in e.response.
print("Error: " + str(e))
except Exception as e:
# Other errors are possible, such as IOError.
print("Error: " + str(e))
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop()
self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop, **kwargs)
@ -90,11 +95,11 @@ class HTTPClient(object):
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError`.
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
response.rethrow()
return response
@ -131,6 +136,9 @@ class AsyncHTTPClient(Configurable):
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
@classmethod
def configurable_base(cls):
@ -195,7 +203,7 @@ class AsyncHTTPClient(Configurable):
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
def fetch(self, request, callback=None, **kwargs):
def fetch(self, request, callback=None, raise_error=True, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
@ -203,8 +211,10 @@ class AsyncHTTPClient(Configurable):
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
the request returned a non-200 response code.
`HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
if the request returned a non-200 response code. Instead, if
``raise_error`` is set to False, the response will always be
returned regardless of the response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
@ -238,7 +248,7 @@ class AsyncHTTPClient(Configurable):
future.add_done_callback(handle_future)
def handle_response(response):
if response.error:
if raise_error and response.error:
future.set_exception(response.error)
else:
future.set_result(response)
@ -299,7 +309,8 @@ class HTTPRequest(object):
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False, decompress_response=None):
expect_100_continue=False, decompress_response=None,
ssl_options=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
@ -369,12 +380,15 @@ class HTTPRequest(object):
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
``simple_httpclient`` and true in ``curl_httpclient``
:arg string client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg string client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
``simple_httpclient`` (unsupported by ``curl_httpclient``).
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
and ``client_cert``.
:arg bool allow_ipv6: Use IPv6 when available? Default is true.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
@ -397,6 +411,9 @@ class HTTPRequest(object):
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
.. versionadded:: 4.2
The ``ssl_options`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
@ -434,6 +451,7 @@ class HTTPRequest(object):
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.ssl_options = ssl_options
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
@ -585,9 +603,12 @@ class HTTPError(Exception):
"""
def __init__(self, code, message=None, response=None):
self.code = code
message = message or httputil.responses.get(code, "Unknown")
self.message = message or httputil.responses.get(code, "Unknown")
self.response = response
Exception.__init__(self, "HTTP %d: %s" % (self.code, message))
super(HTTPError, self).__init__(code, message, response)
def __str__(self):
return "HTTP %d: %s" % (self.code, self.message)
class _RequestProxy(object):

View file

@ -37,34 +37,17 @@ from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado.util import Configurable
class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
class HTTPServer(TCPServer, Configurable,
httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by either a request callback that takes a
`.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
instance.
A simple example server that echoes back the URI you requested::
import tornado.httpserver
import tornado.ioloop
def handle_request(request):
message = "You requested %s\n" % request.uri
request.connection.write_headers(
httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
{"Content-Length": str(len(message))})
request.connection.write(message)
request.connection.finish()
http_server = tornado.httpserver.HTTPServer(handle_request)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
Applications should use the methods of `.HTTPConnection` to write
their response.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
or, for backwards compatibility, a callback that takes an
`.HTTPServerRequest` as an argument. The delegate is usually a
`tornado.web.Application`.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
@ -79,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
To make this server serve SSL traffic, send the ``ssl_options`` dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
an `ssl.SSLContext` object instead of a dict)::
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
HTTPServer(applicaton, ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
HTTPServer(applicaton, ssl_options=ssl_ctx)
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
@ -96,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.listen(8888)
IOLoop.instance().start()
IOLoop.current().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
@ -107,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
@ -119,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.instance().start()
IOLoop.current().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
@ -133,13 +116,29 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
.. versionchanged:: 4.1
`.HTTPServerConnectionDelegate.start_request` is now called with
two arguments ``(server_conn, request_conn)`` (in accordance with the
documentation) instead of one ``(request_conn)``.
.. versionchanged:: 4.2
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
def __init__(self, *args, **kwargs):
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# weird in initialization order between this class,
# Configurable, and TCPServer so we can't leave __init__ out
# completely)
pass
def initialize(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
@ -156,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
read_chunk_size=chunk_size)
self._connections = set()
@classmethod
def configurable_base(cls):
return HTTPServer
@classmethod
def configurable_default(cls):
return HTTPServer
@gen.coroutine
def close_all_connections(self):
while self._connections:
@ -172,7 +179,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
return _ServerRequestAdapter(self, request_conn)
return _ServerRequestAdapter(self, server_conn, request_conn)
def on_close(self, server_conn):
self._connections.remove(server_conn)
@ -181,7 +188,6 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
class _HTTPRequestContext(object):
def __init__(self, stream, address, protocol):
self.address = address
self.protocol = protocol
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
@ -245,13 +251,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
def __init__(self, server, connection):
def __init__(self, server, server_conn, request_conn):
self.server = server
self.connection = connection
self.connection = request_conn
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
self.delegate = server.request_callback.start_request(connection)
self.delegate = server.request_callback.start_request(
server_conn, request_conn)
self._chunks = None
else:
self.delegate = None

View file

@ -33,7 +33,7 @@ import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict, bytes_type
from tornado.util import ObjectDict
try:
import Cookie # py2
@ -62,6 +62,11 @@ except ImportError:
pass
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
# terminator and ignore any preceding CR.
_CRLF_RE = re.compile(r'\r?\n')
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
@ -93,7 +98,7 @@ class _NormalizedHeaderCache(dict):
_normalized_headers = _NormalizedHeaderCache(1000)
class HTTPHeaders(dict):
class HTTPHeaders(collections.MutableMapping):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.
Supports multiple values per key via a pair of new methods,
@ -122,9 +127,7 @@ class HTTPHeaders(dict):
Set-Cookie: C=D
"""
def __init__(self, *args, **kwargs):
# Don't pass args or kwargs to dict.__init__, as it will bypass
# our __setitem__
dict.__init__(self)
self._dict = {}
self._as_list = {}
self._last_key = None
if (len(args) == 1 and len(kwargs) == 0 and
@ -143,10 +146,8 @@ class HTTPHeaders(dict):
norm_name = _normalized_headers[name]
self._last_key = norm_name
if norm_name in self:
# bypass our override of __setitem__ since it modifies _as_list
dict.__setitem__(self, norm_name,
native_str(self[norm_name]) + ',' +
native_str(value))
self._dict[norm_name] = (native_str(self[norm_name]) + ',' +
native_str(value))
self._as_list[norm_name].append(value)
else:
self[norm_name] = value
@ -178,8 +179,7 @@ class HTTPHeaders(dict):
# continuation of a multi-line header
new_part = ' ' + line.lstrip()
self._as_list[self._last_key][-1] += new_part
dict.__setitem__(self, self._last_key,
self[self._last_key] + new_part)
self._dict[self._last_key] += new_part
else:
name, value = line.split(":", 1)
self.add(name, value.strip())
@ -193,42 +193,41 @@ class HTTPHeaders(dict):
[('Content-Length', '42'), ('Content-Type', 'text/html')]
"""
h = cls()
for line in headers.splitlines():
for line in _CRLF_RE.split(headers):
if line:
h.parse_line(line)
return h
# dict implementation overrides
# MutableMapping abstract method implementations.
def __setitem__(self, name, value):
norm_name = _normalized_headers[name]
dict.__setitem__(self, norm_name, value)
self._dict[norm_name] = value
self._as_list[norm_name] = [value]
def __getitem__(self, name):
return dict.__getitem__(self, _normalized_headers[name])
return self._dict[_normalized_headers[name]]
def __delitem__(self, name):
norm_name = _normalized_headers[name]
dict.__delitem__(self, norm_name)
del self._dict[norm_name]
del self._as_list[norm_name]
def __contains__(self, name):
norm_name = _normalized_headers[name]
return dict.__contains__(self, norm_name)
def __len__(self):
return len(self._dict)
def get(self, name, default=None):
return dict.get(self, _normalized_headers[name], default)
def update(self, *args, **kwargs):
# dict.update bypasses our __setitem__
for k, v in dict(*args, **kwargs).items():
self[k] = v
def __iter__(self):
return iter(self._dict)
def copy(self):
# default implementation returns dict(self), not the subclass
# defined in dict but not in MutableMapping.
return HTTPHeaders(self)
# Use our overridden copy method for the copy.copy module.
# This makes shallow copies one level deeper, but preserves
# the appearance that HTTPHeaders is a single container.
__copy__ = copy
class HTTPServerRequest(object):
"""A single HTTP request.
@ -331,11 +330,11 @@ class HTTPServerRequest(object):
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
self.body = body or ""
self.body = body or b""
# set remote IP and protocol
context = getattr(connection, 'context', None)
self.remote_ip = getattr(context, 'remote_ip')
self.remote_ip = getattr(context, 'remote_ip', None)
self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
@ -379,7 +378,9 @@ class HTTPServerRequest(object):
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
assert isinstance(chunk, bytes_type)
assert isinstance(chunk, bytes)
assert self.version.startswith("HTTP/1."), \
"deprecated interface only supported in HTTP/1.x"
self.connection.write(chunk, callback=callback)
def finish(self):
@ -406,15 +407,14 @@ class HTTPServerRequest(object):
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
To use client certificates, the HTTPServer's
`ssl.SSLContext.verify_mode` field must be set, e.g.::
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain("foo.crt", "foo.key")
ssl_ctx.load_verify_locations("cacerts.pem")
ssl_ctx.verify_mode = ssl.CERT_REQUIRED
server = HTTPServer(app, ssl_options=ssl_ctx)
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
@ -543,6 +543,8 @@ class HTTPConnection(object):
headers.
:arg callback: a callback to be run when the write is complete.
The ``version`` field of ``start_line`` is ignored.
Returns a `.Future` if no callback is given.
"""
raise NotImplementedError()
@ -562,11 +564,18 @@ class HTTPConnection(object):
def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether
"""Concatenate url and arguments regardless of whether
url has existing query parameters.
``args`` may be either a dictionary or a list of key-value pairs
(the latter allows for multiple values with the same key.
>>> url_concat("http://example.com/foo", dict(c="d"))
'http://example.com/foo?c=d'
>>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d'
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
'http://example.com/foo?a=b&c=d&c=d2'
"""
if not args:
return url
@ -682,14 +691,17 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
gen_log.warning("Invalid multipart/form-data")
try:
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
raise ValueError("multipart boundary not found")
except Exception as e:
gen_log.warning("Invalid multipart/form-data: %s", e)
def parse_multipart_form_data(boundary, data, arguments, files):
@ -775,7 +787,7 @@ def parse_request_start_line(line):
method, path, version = line.split(" ")
except ValueError:
raise HTTPInputError("Malformed HTTP request line")
if not version.startswith("HTTP/"):
if not re.match(r"^HTTP/1\.[0-9]$", version):
raise HTTPInputError(
"Malformed HTTP version in HTTP Request-Line: %r" % version)
return RequestStartLine(method, path, version)
@ -794,7 +806,7 @@ def parse_response_start_line(line):
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
return ResponseStartLine(match.group(1), int(match.group(2)),
@ -803,6 +815,8 @@ def parse_response_start_line(line):
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
# It has also been modified to support valueless parameters as seen in
# websocket extension negotiations.
def _parseparam(s):
@ -836,9 +850,48 @@ def _parse_header(line):
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value
else:
pdict[p] = None
return key, pdict
def _encode_header(key, pdict):
"""Inverse of _parse_header.
>>> _encode_header('permessage-deflate',
... {'client_max_window_bits': 15, 'client_no_context_takeover': None})
'permessage-deflate; client_max_window_bits=15; client_no_context_takeover'
"""
if not pdict:
return key
out = [key]
# Sort the parameters just to make it easy to test.
for k, v in sorted(pdict.items()):
if v is None:
out.append(k)
else:
# TODO: quote if necessary.
out.append('%s=%s' % (k, v))
return '; '.join(out)
def doctests():
import doctest
return doctest.DocTestSuite()
def split_host_and_port(netloc):
"""Returns ``(host, port)`` tuple from ``netloc``.
Returned ``port`` will be ``None`` if not present.
.. versionadded:: 4.1
"""
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
port = None
return (host, port)

View file

@ -41,6 +41,7 @@ import sys
import threading
import time
import traceback
import math
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
@ -76,35 +77,52 @@ class IOLoop(Configurable):
simultaneous connections, you should use a system that supports
either ``epoll`` or ``kqueue``.
Example usage for a simple TCP server::
Example usage for a simple TCP server:
.. testcode::
import errno
import functools
import ioloop
import tornado.ioloop
import socket
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error, e:
except socket.error as e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
handle_connection(connection, address)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", port))
sock.listen(128)
if __name__ == '__main__':
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", port))
sock.listen(128)
io_loop = ioloop.IOLoop.instance()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
io_loop = tornado.ioloop.IOLoop.current()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
.. testoutput::
:hide:
By default, a newly-constructed `IOLoop` becomes the thread's current
`IOLoop`, unless there already is a current `IOLoop`. This behavior
can be controlled with the ``make_current`` argument to the `IOLoop`
constructor: if ``make_current=True``, the new `IOLoop` will always
try to become current and it raises an error if there is already a
current instance. If ``make_current=False``, the new `IOLoop` will
not try to become current.
.. versionchanged:: 4.2
Added the ``make_current`` keyword argument to the `IOLoop`
constructor.
"""
# Constants from the epoll module
_EPOLLIN = 0x001
@ -133,7 +151,8 @@ class IOLoop(Configurable):
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
another thread. In most other cases, it is better to use `current()`
to get the current thread's `IOLoop`.
"""
if not hasattr(IOLoop, "_instance"):
with IOLoop._instance_lock:
@ -167,28 +186,26 @@ class IOLoop(Configurable):
del IOLoop._instance
@staticmethod
def current():
def current(instance=True):
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as current
by `make_current`, returns that instance. Otherwise returns
`IOLoop.instance()`, i.e. the main thread's `IOLoop`.
A common pattern for classes that depend on ``IOLoops`` is to use
a default argument to enable programs with multiple ``IOLoops``
but not require the argument for simpler applications::
class MyClass(object):
def __init__(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
If an `IOLoop` is currently running or has been marked as
current by `make_current`, returns that instance. If there is
no current `IOLoop`, returns `IOLoop.instance()` (i.e. the
main thread's `IOLoop`, creating one if necessary) if ``instance``
is true.
In general you should use `IOLoop.current` as the default when
constructing an asynchronous object, and use `IOLoop.instance`
when you mean to communicate to the main thread from a different
one.
.. versionchanged:: 4.1
Added ``instance`` argument to control the fallback to
`IOLoop.instance()`.
"""
current = getattr(IOLoop._current, "instance", None)
if current is None:
if current is None and instance:
return IOLoop.instance()
return current
@ -197,9 +214,13 @@ class IOLoop(Configurable):
An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call
`make_current` explictly before starting the `IOLoop`,
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
.. versionchanged:: 4.1
An `IOLoop` created while there is no current `IOLoop`
will automatically become current.
"""
IOLoop._current.instance = self
@ -223,8 +244,14 @@ class IOLoop(Configurable):
from tornado.platform.select import SelectIOLoop
return SelectIOLoop
def initialize(self):
pass
def initialize(self, make_current=None):
if make_current is None:
if IOLoop.current(instance=False) is None:
self.make_current()
elif make_current:
if IOLoop.current(instance=False) is not None:
raise RuntimeError("current IOLoop already exists")
self.make_current()
def close(self, all_fds=False):
"""Closes the `IOLoop`, freeing any resources used.
@ -373,10 +400,12 @@ class IOLoop(Configurable):
def run_sync(self, func, timeout=None):
"""Starts the `IOLoop`, runs the given function, and stops the loop.
If the function returns a `.Future`, the `IOLoop` will run
until the future is resolved. If it raises an exception, the
`IOLoop` will stop and the exception will be re-raised to the
caller.
The function must return either a yieldable object or
``None``. If the function returns a yieldable object, the
`IOLoop` will run until the yieldable is resolved (and
`run_sync()` will return the yieldable's result). If it raises
an exception, the `IOLoop` will stop and the exception will be
re-raised to the caller.
The keyword-only argument ``timeout`` may be used to set
a maximum duration for the function. If the timeout expires,
@ -390,13 +419,19 @@ class IOLoop(Configurable):
# do stuff...
if __name__ == '__main__':
IOLoop.instance().run_sync(main)
IOLoop.current().run_sync(main)
.. versionchanged:: 4.3
Returning a non-``None``, non-yieldable value is now an error.
"""
future_cell = [None]
def run():
try:
result = func()
if result is not None:
from tornado.gen import convert_yielded
result = convert_yielded(result)
except Exception:
future_cell[0] = TracebackFuture()
future_cell[0].set_exc_info(sys.exc_info())
@ -477,7 +512,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0
"""
self.call_at(self.time() + delay, callback, *args, **kwargs)
return self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(self, when, callback, *args, **kwargs):
"""Runs the ``callback`` at the absolute time designated by ``when``.
@ -493,7 +528,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0
"""
self.add_timeout(when, callback, *args, **kwargs)
return self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout):
"""Cancels a pending timeout.
@ -563,12 +598,21 @@ class IOLoop(Configurable):
"""
try:
ret = callback()
if ret is not None and is_future(ret):
if ret is not None:
from tornado import gen
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
# gets logged too.
self.add_future(ret, lambda f: f.result())
try:
ret = gen.convert_yielded(ret)
except gen.BadYieldError:
# It's not unusual for add_callback to be used with
# methods returning a non-None and non-yieldable
# result, which should just be ignored.
pass
else:
self.add_future(ret, lambda f: f.result())
except Exception:
self.handle_callback_exception(callback)
@ -633,8 +677,8 @@ class PollIOLoop(IOLoop):
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
`tornado.platform.select.SelectIOLoop` (all platforms).
"""
def initialize(self, impl, time_func=None):
super(PollIOLoop, self).initialize()
def initialize(self, impl, time_func=None, **kwargs):
super(PollIOLoop, self).initialize(**kwargs)
self._impl = impl
if hasattr(self._impl, 'fileno'):
set_close_exec(self._impl.fileno())
@ -724,7 +768,7 @@ class PollIOLoop(IOLoop):
#
# If someone has already set a wakeup fd, we don't want to
# disturb it. This is an issue for twisted, which does its
# SIGCHILD processing in response to its own wakeup fd being
# SIGCHLD processing in response to its own wakeup fd being
# written to. As long as the wakeup fd is registered on the IOLoop,
# the loop will still wake up and everything should work.
old_wakeup_fd = None
@ -739,8 +783,10 @@ class PollIOLoop(IOLoop):
# IOLoop is just started once at the beginning.
signal.set_wakeup_fd(old_wakeup_fd)
old_wakeup_fd = None
except ValueError: # non-main thread
pass
except ValueError:
# Non-main thread, or the previous value of wakeup_fd
# is no longer valid.
old_wakeup_fd = None
try:
while True:
@ -754,17 +800,18 @@ class PollIOLoop(IOLoop):
# Do not run anything until we have determined which ones
# are ready, so timeouts that call add_timeout cannot
# schedule anything in this iteration.
due_timeouts = []
if self._timeouts:
now = self.time()
while self._timeouts:
if self._timeouts[0].callback is None:
# the timeout was cancelled
# The timeout was cancelled. Note that the
# cancellation check is repeated below for timeouts
# that are cancelled by another timeout or callback.
heapq.heappop(self._timeouts)
self._cancellations -= 1
elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts)
callbacks.append(timeout.callback)
del timeout
due_timeouts.append(heapq.heappop(self._timeouts))
else:
break
if (self._cancellations > 512
@ -778,9 +825,12 @@ class PollIOLoop(IOLoop):
for callback in callbacks:
self._run_callback(callback)
for timeout in due_timeouts:
if timeout.callback is not None:
self._run_callback(timeout.callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = None
callbacks = callback = due_timeouts = timeout = None
if self._callbacks:
# If any callbacks or timeouts called add_callback,
@ -876,38 +926,40 @@ class PollIOLoop(IOLoop):
self._cancellations += 1
def add_callback(self, callback, *args, **kwargs):
with self._callback_lock:
if thread.get_ident() != self._thread_ident:
# If we're not on the IOLoop's thread, we need to synchronize
# with other threads, or waking logic will induce a race.
with self._callback_lock:
if self._closing:
return
list_empty = not self._callbacks
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty:
# If we're not in the IOLoop's thread, and we added the
# first callback to an empty list, we may need to wake it
# up (it may wake up on its own, but an occasional extra
# wake is harmless). Waking up a polling IOLoop is
# relatively expensive, so we try to avoid it when we can.
self._waker.wake()
else:
if self._closing:
raise RuntimeError("IOLoop is closing")
list_empty = not self._callbacks
return
# If we're on the IOLoop's thread, we don't need the lock,
# since we don't need to wake anyone, just add the
# callback. Blindly insert into self._callbacks. This is
# safe even from signal handlers because the GIL makes
# list.append atomic. One subtlety is that if the signal
# is interrupting another thread holding the
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks, but
# either way will work.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty and thread.get_ident() != self._thread_ident:
# If we're in the IOLoop's thread, we know it's not currently
# polling. If we're not, and we added the first callback to an
# empty list, we may need to wake it up (it may wake up on its
# own, but an occasional extra wake is harmless). Waking
# up a polling IOLoop is relatively expensive, so we try to
# avoid it when we can.
self._waker.wake()
def add_callback_from_signal(self, callback, *args, **kwargs):
with stack_context.NullContext():
if thread.get_ident() != self._thread_ident:
# if the signal is handled on another thread, we can add
# it normally (modulo the NullContext)
self.add_callback(callback, *args, **kwargs)
else:
# If we're on the IOLoop's thread, we cannot use
# the regular add_callback because it may deadlock on
# _callback_lock. Blindly insert into self._callbacks.
# This is safe because the GIL makes list.append atomic.
# One subtlety is that if the signal interrupted the
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks,
# but either way will work.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
self.add_callback(callback, *args, **kwargs)
class _Timeout(object):
@ -940,8 +992,16 @@ class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
Note that the timeout is given in milliseconds, while most other
time-related functions in Tornado use seconds.
If the callback runs for longer than ``callback_time`` milliseconds,
subsequent invocations will be skipped to get back on schedule.
`start` must be called after the `PeriodicCallback` is created.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, callback, callback_time, io_loop=None):
self.callback = callback
@ -965,18 +1025,29 @@ class PeriodicCallback(object):
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def is_running(self):
"""Return True if this `.PeriodicCallback` has been started.
.. versionadded:: 4.1
"""
return self._running
def _run(self):
if not self._running:
return
try:
self.callback()
return self.callback()
except Exception:
self.io_loop.handle_callback_exception(self.callback)
self._schedule_next()
finally:
self._schedule_next()
def _schedule_next(self):
if self._running:
current_time = self.io_loop.time()
while self._next_timeout <= current_time:
self._next_timeout += self.callback_time / 1000.0
if self._next_timeout <= current_time:
callback_time_sec = self.callback_time / 1000.0
self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)

View file

@ -37,9 +37,9 @@ import re
from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults
from tornado import stack_context
from tornado.util import bytes_type, errno_from_exception
from tornado.util import errno_from_exception
try:
from tornado.platform.posix import _set_nonblocking
@ -68,21 +68,37 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
if hasattr(errno, "WSAECONNRESET"):
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
if sys.platform == 'darwin':
# OSX appears to have a race condition that causes send(2) to return
# EPROTOTYPE if called while a socket is being torn down:
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
# Since the socket is being closed anyway, treat this as an ECONNRESET
# instead of an unexpected error.
_ERRNO_CONNRESET += (errno.EPROTOTYPE,)
# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
#######################################################
class StreamClosedError(IOError):
"""Exception raised by `IOStream` methods when the stream is closed.
Note that the close callback is scheduled to run *after* other
callbacks on the stream (to allow for buffered data to be processed),
so you may see this error before you see the close callback.
The ``real_error`` attribute contains the underlying error that caused
the stream to close (if any).
.. versionchanged:: 4.3
Added the ``real_error`` attribute.
"""
pass
def __init__(self, real_error=None):
super(StreamClosedError, self).__init__('Stream is closed')
self.real_error = real_error
class UnsatisfiableReadError(Exception):
@ -122,6 +138,7 @@ class BaseIOStream(object):
"""`BaseIOStream` constructor.
:arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`.
Deprecated since Tornado 4.1.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
defaults to 100MB.
:arg read_chunk_size: Amount of data to read at one time from the
@ -160,6 +177,11 @@ class BaseIOStream(object):
self._close_callback = None
self._connect_callback = None
self._connect_future = None
# _ssl_connect_future should be defined in SSLIOStream
# but it's here so we can clean it up in maybe_run_close_callback.
# TODO: refactor that so subclasses can add additional futures
# to be cancelled.
self._ssl_connect_future = None
self._connecting = False
self._state = None
self._pending_callbacks = 0
@ -230,6 +252,12 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
except:
if future is not None:
# Ensure that the future doesn't log an error because its
# failure was never examined.
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_until(self, delimiter, callback=None, max_bytes=None):
@ -257,6 +285,10 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
except:
if future is not None:
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
@ -281,7 +313,12 @@ class BaseIOStream(object):
self._read_bytes = num_bytes
self._read_partial = partial
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()
try:
self._try_inline_read()
except:
if future is not None:
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_until_close(self, callback=None, streaming_callback=None):
@ -293,9 +330,16 @@ class BaseIOStream(object):
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
Note that if a ``streaming_callback`` is used, data will be
read from the socket as quickly as it becomes available; there
is no way to apply backpressure or cancel the reads. If flow
control or cancellation are desired, use a loop with
`read_bytes(partial=True) <.read_bytes>` instead.
.. versionchanged:: 4.0
The callback argument is now optional and a `.Future` will
be returned if it is omitted.
"""
future = self._set_read_callback(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
@ -305,7 +349,12 @@ class BaseIOStream(object):
self._run_read_callback(self._read_buffer_size, False)
return future
self._read_until_close = True
self._try_inline_read()
try:
self._try_inline_read()
except:
if future is not None:
future.add_done_callback(lambda f: f.exception())
raise
return future
def write(self, data, callback=None):
@ -324,14 +373,14 @@ class BaseIOStream(object):
.. versionchanged:: 4.0
Now returns a `.Future` if no callback is given.
"""
assert isinstance(data, bytes_type)
assert isinstance(data, bytes)
self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
if data:
if (self.max_write_buffer_size is not None and
self._write_buffer_size + len(data) > self.max_write_buffer_size):
raise StreamBufferFullError("Reached maximum read buffer size")
raise StreamBufferFullError("Reached maximum write buffer size")
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
@ -344,6 +393,7 @@ class BaseIOStream(object):
future = None
else:
future = self._write_future = TracebackFuture()
future.add_done_callback(lambda f: f.exception())
if not self._connecting:
self._handle_write()
if self._write_buffer:
@ -401,15 +451,11 @@ class BaseIOStream(object):
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
if self._ssl_connect_future is not None:
futures.append(self._ssl_connect_future)
self._ssl_connect_future = None
for future in futures:
if (isinstance(self.error, (socket.error, IOError)) and
errno_from_exception(self.error) in _ERRNO_CONNRESET):
# Treat connection resets as closed connections so
# clients only have to catch one kind of exception
# to avoid logging.
future.set_exception(StreamClosedError())
else:
future.set_exception(self.error or StreamClosedError())
future.set_exception(StreamClosedError(real_error=self.error))
if self._close_callback is not None:
cb = self._close_callback
self._close_callback = None
@ -505,7 +551,7 @@ class BaseIOStream(object):
def wrapper():
self._pending_callbacks -= 1
try:
callback(*args)
return callback(*args)
except Exception:
app_log.error("Uncaught exception, closing connection.",
exc_info=True)
@ -517,7 +563,8 @@ class BaseIOStream(object):
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise
self._maybe_add_error_listener()
finally:
self._maybe_add_error_listener()
# We schedule callbacks to be run on the next IOLoop iteration
# rather than running them directly for several reasons:
# * Prevents unbounded stack growth when a callback calls an
@ -553,7 +600,7 @@ class BaseIOStream(object):
# Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either
# estabilsh a real pending callback via
# establish a real pending callback via
# _read_from_buffer or run the close callback.
#
# We need two try statements here so that
@ -600,8 +647,8 @@ class BaseIOStream(object):
pos = self._read_to_buffer_loop()
except UnsatisfiableReadError:
raise
except Exception:
gen_log.warning("error on read", exc_info=True)
except Exception as e:
gen_log.warning("error on read: %s" % e)
self.close(exc_info=True)
return
if pos is not None:
@ -625,13 +672,13 @@ class BaseIOStream(object):
else:
callback = self._read_callback
self._read_callback = self._streaming_callback = None
if self._read_future is not None:
assert callback is None
future = self._read_future
self._read_future = None
future.set_result(self._consume(size))
if self._read_future is not None:
assert callback is None
future = self._read_future
self._read_future = None
future.set_result(self._consume(size))
if callback is not None:
assert self._read_future is None
assert (self._read_future is None) or streaming
self._run_callback(callback, self._consume(size))
else:
# If we scheduled a callback, we will add the error listener
@ -678,18 +725,22 @@ class BaseIOStream(object):
to read (i.e. the read returns EWOULDBLOCK or equivalent). On
error closes the socket and raises an exception.
"""
try:
chunk = self.read_from_fd()
except (socket.error, IOError, OSError) as e:
# ssl.SSLError is a subclass of socket.error
if e.args[0] in _ERRNO_CONNRESET:
# Treat ECONNRESET as a connection close rather than
# an error to minimize log spam (the exception will
# be available on self.error for apps that care).
while True:
try:
chunk = self.read_from_fd()
except (socket.error, IOError, OSError) as e:
if errno_from_exception(e) == errno.EINTR:
continue
# ssl.SSLError is a subclass of socket.error
if self._is_connreset(e):
# Treat ECONNRESET as a connection close rather than
# an error to minimize log spam (the exception will
# be available on self.error for apps that care).
self.close(exc_info=True)
return
self.close(exc_info=True)
return
self.close(exc_info=True)
raise
raise
break
if chunk is None:
return 0
self._read_buffer.append(chunk)
@ -804,7 +855,7 @@ class BaseIOStream(object):
self._write_buffer_frozen = True
break
else:
if e.args[0] not in _ERRNO_CONNRESET:
if not self._is_connreset(e):
# Broken pipe errors are usually caused by connection
# reset, and its better to not log EPIPE errors to
# minimize log spam
@ -831,7 +882,7 @@ class BaseIOStream(object):
def _check_closed(self):
if self.closed():
raise StreamClosedError("Stream is closed")
raise StreamClosedError(real_error=self.error)
def _maybe_add_error_listener(self):
# This method is part of an optimization: to detect a connection that
@ -882,6 +933,14 @@ class BaseIOStream(object):
self._state = self._state | state
self.io_loop.update_handler(self.fileno(), self._state)
def _is_connreset(self, exc):
"""Return true if exc is ECONNRESET or equivalent.
May be overridden in subclasses.
"""
return (isinstance(exc, (socket.error, IOError)) and
errno_from_exception(exc) in _ERRNO_CONNRESET)
class IOStream(BaseIOStream):
r"""Socket-based `IOStream` implementation.
@ -896,7 +955,9 @@ class IOStream(BaseIOStream):
connected before passing it to the `IOStream` or connected with
`IOStream.connect`.
A very simple (and broken) HTTP client using this class::
A very simple (and broken) HTTP client using this class:
.. testcode::
import tornado.ioloop
import tornado.iostream
@ -915,14 +976,19 @@ class IOStream(BaseIOStream):
stream.read_bytes(int(headers[b"Content-Length"]), on_body)
def on_body(data):
print data
print(data)
stream.close()
tornado.ioloop.IOLoop.instance().stop()
tornado.ioloop.IOLoop.current().stop()
if __name__ == '__main__':
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = tornado.iostream.IOStream(s)
stream.connect(("friendfeed.com", 80), send_request)
tornado.ioloop.IOLoop.current().start()
.. testoutput::
:hide:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = tornado.iostream.IOStream(s)
stream.connect(("friendfeed.com", 80), send_request)
tornado.ioloop.IOLoop.instance().start()
"""
def __init__(self, socket, *args, **kwargs):
self.socket = socket
@ -976,10 +1042,10 @@ class IOStream(BaseIOStream):
returns a `.Future` (whose result after a successful
connection will be the stream itself).
If specified, the ``server_hostname`` parameter will be used
in SSL connections for certificate validation (if requested in
the ``ssl_options``) and SNI (if supported; requires
Python 3.2+).
In SSL mode, the ``server_hostname`` parameter will be used
for certificate validation (unless disabled in the
``ssl_options``) and SNI (if supported; requires Python
2.7.9+).
Note that it is safe to call `IOStream.write
<BaseIOStream.write>` while the connection is pending, in
@ -990,8 +1056,18 @@ class IOStream(BaseIOStream):
.. versionchanged:: 4.0
If no callback is given, returns a `.Future`.
.. versionchanged:: 4.2
SSL certificates are validated by default; pass
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to the
`SSLIOStream` constructor to disable.
"""
self._connecting = True
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
try:
self.socket.connect(address)
except socket.error as e:
@ -1004,15 +1080,11 @@ class IOStream(BaseIOStream):
# reported later in _handle_connect.
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e)
if future is None:
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e)
self.close(exc_info=True)
return
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
return future
self._add_io_state(self.io_loop.WRITE)
return future
@ -1031,10 +1103,11 @@ class IOStream(BaseIOStream):
data. It can also be used immediately after connecting,
before any reads or writes.
The ``ssl_options`` argument may be either a dictionary
of options or an `ssl.SSLContext`. If a ``server_hostname``
is given, it will be used for certificate verification
(as configured in the ``ssl_options``).
The ``ssl_options`` argument may be either an `ssl.SSLContext`
object or a dictionary of keyword arguments for the
`ssl.wrap_socket` function. The ``server_hostname`` argument
will be used for certificate validation unless disabled
in the ``ssl_options``.
This method returns a `.Future` whose result is the new
`SSLIOStream`. After this method has been called,
@ -1044,6 +1117,11 @@ class IOStream(BaseIOStream):
transferred to the new stream.
.. versionadded:: 4.0
.. versionchanged:: 4.2
SSL certificates are validated by default; pass
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to disable.
"""
if (self._read_callback or self._read_future or
self._write_callback or self._write_future or
@ -1052,12 +1130,17 @@ class IOStream(BaseIOStream):
self._read_buffer or self._write_buffer):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
ssl_options = {}
if server_side:
ssl_options = _server_ssl_defaults
else:
ssl_options = _client_ssl_defaults
socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None
socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
socket = ssl_wrap_socket(socket, ssl_options,
server_hostname=server_hostname,
server_side=server_side,
do_handshake_on_connect=False)
orig_close_callback = self._close_callback
self._close_callback = None
@ -1069,8 +1152,18 @@ class IOStream(BaseIOStream):
# If we had an "unwrap" counterpart to this method we would need
# to restore the original callback after our Future resolves
# so that repeated wrap/unwrap calls don't build up layers.
def close_callback():
if not future.done():
# Note that unlike most Futures returned by IOStream,
# this one passes the underlying error through directly
# instead of wrapping everything in a StreamClosedError
# with a real_error attribute. This is because once the
# connection is established it's more helpful to raise
# the SSLError directly than to hide it behind a
# StreamClosedError (and the client is expecting SSL
# issues rather than network issues since this method is
# named start_tls).
future.set_exception(ssl_stream.error or StreamClosedError())
if orig_close_callback is not None:
orig_close_callback()
@ -1113,7 +1206,7 @@ class IOStream(BaseIOStream):
# Sometimes setsockopt will fail if the socket is closed
# at the wrong time. This can happen with HTTPServer
# resetting the value to false between requests.
if e.errno not in (errno.EINVAL, errno.ECONNRESET):
if e.errno != errno.EINVAL and not self._is_connreset(e):
raise
@ -1129,11 +1222,11 @@ class SSLIOStream(IOStream):
wrapped when `IOStream.connect` is finished.
"""
def __init__(self, *args, **kwargs):
"""The ``ssl_options`` keyword argument may either be a dictionary
of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext`
object.
"""The ``ssl_options`` keyword argument may either be an
`ssl.SSLContext` object or a dictionary of keywords arguments
for `ssl.wrap_socket`
"""
self._ssl_options = kwargs.pop('ssl_options', {})
self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
super(SSLIOStream, self).__init__(*args, **kwargs)
self._ssl_accepting = True
self._handshake_reading = False
@ -1184,8 +1277,14 @@ class SSLIOStream(IOStream):
return self.close(exc_info=True)
raise
except socket.error as err:
if err.args[0] in _ERRNO_CONNRESET:
# Some port scans (e.g. nmap in -sT mode) have been known
# to cause do_handshake to raise EBADF and ENOTCONN, so make
# those errors quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
if (self._is_connreset(err) or
err.args[0] in (errno.EBADF, errno.ENOTCONN)):
return self.close(exc_info=True)
raise
except AttributeError:
# On Linux, if the connection was reset before the call to
# wrap_socket, do_handshake will fail with an
@ -1196,10 +1295,17 @@ class SSLIOStream(IOStream):
if not self._verify_cert(self.socket.getpeercert()):
self.close()
return
if self._ssl_connect_callback is not None:
callback = self._ssl_connect_callback
self._ssl_connect_callback = None
self._run_callback(callback)
self._run_ssl_connect_callback()
def _run_ssl_connect_callback(self):
if self._ssl_connect_callback is not None:
callback = self._ssl_connect_callback
self._ssl_connect_callback = None
self._run_callback(callback)
if self._ssl_connect_future is not None:
future = self._ssl_connect_future
self._ssl_connect_future = None
future.set_result(self)
def _verify_cert(self, peercert):
"""Returns True if peercert is valid according to the configured
@ -1222,8 +1328,8 @@ class SSLIOStream(IOStream):
return False
try:
ssl_match_hostname(peercert, self._server_hostname)
except SSLCertificateError:
gen_log.warning("Invalid SSL certificate", exc_info=True)
except SSLCertificateError as e:
gen_log.warning("Invalid SSL certificate: %s" % e)
return False
else:
return True
@ -1241,14 +1347,11 @@ class SSLIOStream(IOStream):
super(SSLIOStream, self)._handle_write()
def connect(self, address, callback=None, server_hostname=None):
# Save the user's callback and run it after the ssl handshake
# has completed.
self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
# Note: Since we don't pass our callback argument along to
# super.connect(), this will always return a Future.
# This is harmless, but a bit less efficient than it could be.
return super(SSLIOStream, self).connect(address, callback=None)
# Pass a dummy callback to super.connect(), which is slightly
# more efficient than letting it return a Future we ignore.
super(SSLIOStream, self).connect(address, callback=lambda: None)
return self.wait_for_handshake(callback)
def _handle_connect(self):
# Call the superclass method to check for errors.
@ -1273,6 +1376,51 @@ class SSLIOStream(IOStream):
do_handshake_on_connect=False)
self._add_io_state(old_state)
def wait_for_handshake(self, callback=None):
"""Wait for the initial SSL handshake to complete.
If a ``callback`` is given, it will be called with no
arguments once the handshake is complete; otherwise this
method returns a `.Future` which will resolve to the
stream itself after the handshake is complete.
Once the handshake is complete, information such as
the peer's certificate and NPN/ALPN selections may be
accessed on ``self.socket``.
This method is intended for use on server-side streams
or after using `IOStream.start_tls`; it should not be used
with `IOStream.connect` (which already waits for the
handshake to complete). It may only be called once per stream.
.. versionadded:: 4.2
"""
if (self._ssl_connect_callback is not None or
self._ssl_connect_future is not None):
raise RuntimeError("Already waiting")
if callback is not None:
self._ssl_connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._ssl_connect_future = TracebackFuture()
if not self._ssl_accepting:
self._run_ssl_connect_callback()
return future
def write_to_fd(self, data):
try:
return self.socket.send(data)
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
# In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
# the socket is not writeable; we need to transform this into
# an EWOULDBLOCK socket.error or a zero return value,
# either of which will be recognized by the caller of this
# method. Prior to Python 3.5, an unwriteable socket would
# simply return 0 bytes written.
return 0
raise
def read_from_fd(self):
if self._ssl_accepting:
# If the handshake hasn't finished yet, there can't be anything
@ -1303,6 +1451,11 @@ class SSLIOStream(IOStream):
return None
return chunk
def _is_connreset(self, e):
if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
return True
return super(SSLIOStream, self)._is_connreset(e)
class PipeIOStream(BaseIOStream):
"""Pipe-based `IOStream` implementation.

View file

@ -41,8 +41,10 @@ the `Locale.translate` method will simply return the original string.
from __future__ import absolute_import, division, print_function, with_statement
import codecs
import csv
import datetime
from io import BytesIO
import numbers
import os
import re
@ -51,10 +53,13 @@ from tornado import escape
from tornado.log import gen_log
from tornado.util import u
from tornado._locale_data import LOCALE_NAMES
_default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
def get(*locale_codes):
@ -85,7 +90,7 @@ def set_default_locale(code):
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
def load_translations(directory):
def load_translations(directory, encoding=None):
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
@ -105,12 +110,20 @@ def load_translations(directory):
The file is read using the `csv` module in the default "excel" dialect.
In this format there should not be spaces after the commas.
If no ``encoding`` parameter is given, the encoding will be
detected automatically (among UTF-8 and UTF-16) if the file
contains a byte-order marker (BOM), defaulting to UTF-8 if no BOM
is present.
Example translation ``es_LA.csv``::
"I love you","Te amo"
"%(name)s liked this","A %(name)s les gustó esto","plural"
"%(name)s liked this","A %(name)s le gustó esto","singular"
.. versionchanged:: 4.3
Added ``encoding`` parameter. Added support for BOM-based encoding
detection, UTF-16, and UTF-8-with-BOM.
"""
global _translations
global _supported_locales
@ -124,13 +137,29 @@ def load_translations(directory):
os.path.join(directory, path))
continue
full_path = os.path.join(directory, path)
if encoding is None:
# Try to autodetect encoding based on the BOM.
with open(full_path, 'rb') as f:
data = f.read(len(codecs.BOM_UTF16_LE))
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
encoding = 'utf-16'
else:
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
encoding = 'utf-8-sig'
try:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding="utf-8")
f = open(full_path, "r", encoding=encoding)
except TypeError:
# python 2: files return byte strings, which are decoded below.
f = open(full_path, "r")
# python 2: csv can only handle byte strings (in ascii-compatible
# encodings), which we decode below. Transcode everything into
# utf8 before passing it to csv.reader.
f = BytesIO()
with codecs.open(full_path, "r", encoding=encoding) as infile:
f.write(escape.utf8(infile.read()))
f.seek(0)
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2:
@ -273,6 +302,9 @@ class Locale(object):
"""
raise NotImplementedError()
def pgettext(self, context, message, plural_message=None, count=None):
raise NotImplementedError()
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
@ -422,6 +454,11 @@ class CSVLocale(Locale):
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
def pgettext(self, context, message, plural_message=None, count=None):
if self.translations:
gen_log.warning('pgettext is not supported by CSVLocale')
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
@ -445,67 +482,40 @@ class GettextLocale(Locale):
else:
return self.gettext(message)
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},
"ar_AR": {"name_en": u("Arabic"), "name": u("\u0627\u0644\u0639\u0631\u0628\u064a\u0629")},
"bg_BG": {"name_en": u("Bulgarian"), "name": u("\u0411\u044a\u043b\u0433\u0430\u0440\u0441\u043a\u0438")},
"bn_IN": {"name_en": u("Bengali"), "name": u("\u09ac\u09be\u0982\u09b2\u09be")},
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
"ca_ES": {"name_en": u("Catalan"), "name": u("Catal\xe0")},
"cs_CZ": {"name_en": u("Czech"), "name": u("\u010ce\u0161tina")},
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
"el_GR": {"name_en": u("Greek"), "name": u("\u0395\u03bb\u03bb\u03b7\u03bd\u03b9\u03ba\u03ac")},
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Espa\xf1ol (Espa\xf1a)")},
"es_LA": {"name_en": u("Spanish"), "name": u("Espa\xf1ol")},
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
"fa_IR": {"name_en": u("Persian"), "name": u("\u0641\u0627\u0631\u0633\u06cc")},
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Fran\xe7ais (Canada)")},
"fr_FR": {"name_en": u("French"), "name": u("Fran\xe7ais")},
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
"he_IL": {"name_en": u("Hebrew"), "name": u("\u05e2\u05d1\u05e8\u05d9\u05ea")},
"hi_IN": {"name_en": u("Hindi"), "name": u("\u0939\u093f\u0928\u094d\u0926\u0940")},
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
"is_IS": {"name_en": u("Icelandic"), "name": u("\xcdslenska")},
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
"ja_JP": {"name_en": u("Japanese"), "name": u("\u65e5\u672c\u8a9e")},
"ko_KR": {"name_en": u("Korean"), "name": u("\ud55c\uad6d\uc5b4")},
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvi\u0173")},
"lv_LV": {"name_en": u("Latvian"), "name": u("Latvie\u0161u")},
"mk_MK": {"name_en": u("Macedonian"), "name": u("\u041c\u0430\u043a\u0435\u0434\u043e\u043d\u0441\u043a\u0438")},
"ml_IN": {"name_en": u("Malayalam"), "name": u("\u0d2e\u0d32\u0d2f\u0d3e\u0d33\u0d02")},
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokm\xe5l)")},
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
"pa_IN": {"name_en": u("Punjabi"), "name": u("\u0a2a\u0a70\u0a1c\u0a3e\u0a2c\u0a40")},
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Portugu\xeas (Brasil)")},
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Portugu\xeas (Portugal)")},
"ro_RO": {"name_en": u("Romanian"), "name": u("Rom\xe2n\u0103")},
"ru_RU": {"name_en": u("Russian"), "name": u("\u0420\u0443\u0441\u0441\u043a\u0438\u0439")},
"sk_SK": {"name_en": u("Slovak"), "name": u("Sloven\u010dina")},
"sl_SI": {"name_en": u("Slovenian"), "name": u("Sloven\u0161\u010dina")},
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
"sr_RS": {"name_en": u("Serbian"), "name": u("\u0421\u0440\u043f\u0441\u043a\u0438")},
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
"ta_IN": {"name_en": u("Tamil"), "name": u("\u0ba4\u0bae\u0bbf\u0bb4\u0bcd")},
"te_IN": {"name_en": u("Telugu"), "name": u("\u0c24\u0c46\u0c32\u0c41\u0c17\u0c41")},
"th_TH": {"name_en": u("Thai"), "name": u("\u0e20\u0e32\u0e29\u0e32\u0e44\u0e17\u0e22")},
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
"tr_TR": {"name_en": u("Turkish"), "name": u("T\xfcrk\xe7e")},
"uk_UA": {"name_en": u("Ukraini "), "name": u("\u0423\u043a\u0440\u0430\u0457\u043d\u0441\u044c\u043a\u0430")},
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Ti\u1ebfng Vi\u1ec7t")},
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("\u4e2d\u6587(\u7b80\u4f53)")},
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("\u4e2d\u6587(\u7e41\u9ad4)")},
}
def pgettext(self, context, message, plural_message=None, count=None):
"""Allows to set context for translation, accepts plural forms.
Usage example::
pgettext("law", "right")
pgettext("good", "right")
Plural message example::
pgettext("organization", "club", "clubs", len(clubs))
pgettext("stick", "club", "clubs", len(clubs))
To generate POT file with context, add following options to step 1
of `load_gettext_translations` sequence::
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
.. versionadded:: 4.2
"""
if plural_message is not None:
assert count is not None
msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
count)
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = self.ngettext(message, plural_message, count)
return result
else:
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
result = self.gettext(msg_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = message
return result

View file

@ -0,0 +1,512 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
import collections
from tornado import gen, ioloop
from tornado.concurrent import Future
class _TimeoutGarbageCollector(object):
"""Base class for objects that periodically clean up timed-out waiters.
Avoids memory leak in a common pattern like:
while True:
yield condition.wait(short_timeout)
print('looping....')
"""
def __init__(self):
self._waiters = collections.deque() # Futures.
self._timeouts = 0
def _garbage_collect(self):
# Occasionally clear timed-out waiters.
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
self._waiters = collections.deque(
w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
"""A condition allows one or more coroutines to wait until notified.
Like a standard `threading.Condition`, but does not need an underlying lock
that is acquired and released.
With a `Condition`, coroutines can wait to be notified by other coroutines:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Condition
condition = Condition()
@gen.coroutine
def waiter():
print("I'll wait right here")
yield condition.wait() # Yield a Future.
print("I'm done waiting")
@gen.coroutine
def notifier():
print("About to notify")
condition.notify()
print("Done notifying")
@gen.coroutine
def runner():
# Yield two Futures; wait for waiter() and notifier() to finish.
yield [waiter(), notifier()]
IOLoop.current().run_sync(runner)
.. testoutput::
I'll wait right here
About to notify
Done notifying
I'm done waiting
`wait` takes an optional ``timeout`` argument, which is either an absolute
timestamp::
io_loop = IOLoop.current()
# Wait up to 1 second for a notification.
yield condition.wait(timeout=io_loop.time() + 1)
...or a `datetime.timedelta` for a timeout relative to the current time::
# Wait up to 1 second.
yield condition.wait(timeout=datetime.timedelta(seconds=1))
The method raises `tornado.gen.TimeoutError` if there's no notification
before the deadline.
"""
def __init__(self):
super(Condition, self).__init__()
self.io_loop = ioloop.IOLoop.current()
def __repr__(self):
result = '<%s' % (self.__class__.__name__, )
if self._waiters:
result += ' waiters[%s]' % len(self._waiters)
return result + '>'
def wait(self, timeout=None):
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
or ``False`` after a timeout.
"""
waiter = Future()
self._waiters.append(waiter)
if timeout:
def on_timeout():
waiter.set_result(False)
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def notify(self, n=1):
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
waiter = self._waiters.popleft()
if not waiter.done(): # Might have timed out.
n -= 1
waiters.append(waiter)
for waiter in waiters:
waiter.set_result(True)
def notify_all(self):
"""Wake all waiters."""
self.notify(len(self._waiters))
class Event(object):
"""An event blocks coroutines until its internal flag is set to True.
Similar to `threading.Event`.
A coroutine can wait for an event to be set. Once it is set, calls to
``yield event.wait()`` will not block unless the event has been cleared:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event
event = Event()
@gen.coroutine
def waiter():
print("Waiting for event")
yield event.wait()
print("Not waiting this time")
yield event.wait()
print("Done")
@gen.coroutine
def setter():
print("About to set the event")
event.set()
@gen.coroutine
def runner():
yield [waiter(), setter()]
IOLoop.current().run_sync(runner)
.. testoutput::
Waiting for event
About to set the event
Not waiting this time
Done
"""
def __init__(self):
self._future = Future()
def __repr__(self):
return '<%s %s>' % (
self.__class__.__name__, 'set' if self.is_set() else 'clear')
def is_set(self):
"""Return ``True`` if the internal flag is true."""
return self._future.done()
def set(self):
"""Set the internal flag to ``True``. All waiters are awakened.
Calling `.wait` once the flag is set will not block.
"""
if not self._future.done():
self._future.set_result(None)
def clear(self):
"""Reset the internal flag to ``False``.
Calls to `.wait` will block until `.set` is called.
"""
if self._future.done():
self._future = Future()
def wait(self, timeout=None):
"""Block until the internal flag is true.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
if timeout is None:
return self._future
else:
return gen.with_timeout(timeout, self._future)
class _ReleasingContextManager(object):
"""Releases a Lock or Semaphore at the end of a "with" statement.
with (yield semaphore.acquire()):
pass
# Now semaphore.release() has been called.
"""
def __init__(self, obj):
self._obj = obj
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self._obj.release()
class Semaphore(_TimeoutGarbageCollector):
"""A lock that can be acquired a fixed number of times before blocking.
A Semaphore manages a counter representing the number of `.release` calls
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
method blocks if necessary until it can return without making the counter
negative.
Semaphores limit access to a shared resource. To allow access for two
workers at a time:
.. testsetup:: semaphore
from collections import deque
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.concurrent import Future
# Ensure reliable doctest output: resolve Futures one at a time.
futures_q = deque([Future() for _ in range(3)])
@gen.coroutine
def simulator(futures):
for f in futures:
yield gen.moment
f.set_result(None)
IOLoop.current().add_callback(simulator, list(futures_q))
def use_some_resource():
return futures_q.popleft()
.. testcode:: semaphore
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Semaphore
sem = Semaphore(2)
@gen.coroutine
def worker(worker_id):
yield sem.acquire()
try:
print("Worker %d is working" % worker_id)
yield use_some_resource()
finally:
print("Worker %d is done" % worker_id)
sem.release()
@gen.coroutine
def runner():
# Join all workers.
yield [worker(i) for i in range(3)]
IOLoop.current().run_sync(runner)
.. testoutput:: semaphore
Worker 0 is working
Worker 1 is working
Worker 0 is done
Worker 2 is working
Worker 1 is done
Worker 2 is done
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
the semaphore has been released once, by worker 0.
`.acquire` is a context manager, so ``worker`` could be written as::
@gen.coroutine
def worker(worker_id):
with (yield sem.acquire()):
print("Worker %d is working" % worker_id)
yield use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
In Python 3.5, the semaphore itself can be used as an async context
manager::
async def worker(worker_id):
async with sem:
print("Worker %d is working" % worker_id)
await use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""
def __init__(self, value=1):
super(Semaphore, self).__init__()
if value < 0:
raise ValueError('semaphore initial value must be >= 0')
self._value = value
def __repr__(self):
res = super(Semaphore, self).__repr__()
extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
self._value)
if self._waiters:
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def release(self):
"""Increment the counter and wake one waiter."""
self._value += 1
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
self._value -= 1
# If the waiter is a coroutine paused at
#
# with (yield semaphore.acquire()):
#
# then the context manager's __exit__ calls release() at the end
# of the "with" block.
waiter.set_result(_ReleasingContextManager(self))
break
def acquire(self, timeout=None):
"""Decrement the counter. Returns a Future.
Block if the counter is zero and wait for a `.release`. The Future
raises `.TimeoutError` after the deadline.
"""
waiter = Future()
if self._value > 0:
self._value -= 1
waiter.set_result(_ReleasingContextManager(self))
else:
self._waiters.append(waiter)
if timeout:
def on_timeout():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def __enter__(self):
raise RuntimeError(
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
" 'with semaphore'")
__exit__ = __enter__
@gen.coroutine
def __aenter__(self):
yield self.acquire()
@gen.coroutine
def __aexit__(self, typ, value, tb):
self.release()
class BoundedSemaphore(Semaphore):
"""A semaphore that prevents release() being called too many times.
If `.release` would increment the semaphore's value past the initial
value, it raises `ValueError`. Semaphores are mostly used to guard
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
def __init__(self, value=1):
super(BoundedSemaphore, self).__init__(value=value)
self._initial_value = value
def release(self):
"""Increment the counter and wake one waiter."""
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
super(BoundedSemaphore, self).release()
class Lock(object):
"""A lock for coroutines.
A Lock begins unlocked, and `acquire` locks it immediately. While it is
locked, a coroutine that yields `acquire` waits until another coroutine
calls `release`.
Releasing an unlocked lock raises `RuntimeError`.
`acquire` supports the context manager protocol in all Python versions:
>>> from tornado import gen, locks
>>> lock = locks.Lock()
>>>
>>> @gen.coroutine
... def f():
... with (yield lock.acquire()):
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
In Python 3.5, `Lock` also supports the async context manager
protocol. Note that in this case there is no `acquire`, because
``async with`` includes both the ``yield`` and the ``acquire``
(just as it does with `threading.Lock`):
>>> async def f(): # doctest: +SKIP
... async with lock:
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
.. versionchanged:: 3.5
Added ``async with`` support in Python 3.5.
"""
def __init__(self):
self._block = BoundedSemaphore(value=1)
def __repr__(self):
return "<%s _block=%s>" % (
self.__class__.__name__,
self._block)
def acquire(self, timeout=None):
"""Attempt to lock. Returns a Future.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
return self._block.acquire(timeout)
def release(self):
"""Unlock.
The first coroutine in line waiting for `acquire` gets the lock.
If not locked, raise a `RuntimeError`.
"""
try:
self._block.release()
except ValueError:
raise RuntimeError('release unlocked lock')
def __enter__(self):
raise RuntimeError(
"Use Lock like 'with (yield lock)', not like 'with lock'")
__exit__ = __enter__
@gen.coroutine
def __aenter__(self):
yield self.acquire()
@gen.coroutine
def __aexit__(self, typ, value, tb):
self.release()

View file

@ -190,10 +190,22 @@ def enable_pretty_logging(options=None, logger=None):
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups)
rotate_mode = options.log_rotate_mode
if rotate_mode == 'size':
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups)
elif rotate_mode == 'time':
channel = logging.handlers.TimedRotatingFileHandler(
filename=options.log_file_prefix,
when=options.log_rotate_when,
interval=options.log_rotate_interval,
backupCount=options.log_file_num_backups)
else:
error_message = 'The value of log_rotate_mode option should be ' +\
'"size" or "time", not "%s".' % rotate_mode
raise ValueError(error_message)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
@ -206,6 +218,14 @@ def enable_pretty_logging(options=None, logger=None):
def define_logging_options(options=None):
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
this method is only necessary if you have created your own `.OptionParser`.
.. versionadded:: 4.2
This function existed in prior versions but was broken and undocumented until 4.2.
"""
if options is None:
# late import to prevent cycle
from tornado.options import options
@ -227,4 +247,13 @@ def define_logging_options(options=None):
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
options.add_parse_callback(enable_pretty_logging)
options.define("log_rotate_when", type=str, default='midnight',
help=("specify the type of TimedRotatingFileHandler interval "
"other options:('S', 'M', 'H', 'D', 'W0'-'W6')"))
options.define("log_rotate_interval", type=int, default=1,
help="The interval value of timed rotating")
options.define("log_rotate_mode", type=str, default='size',
help="The mode of rotating files(time or size)")
options.add_parse_callback(lambda: enable_pretty_logging(options))

View file

@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import platform
import sys
import socket
import stat
@ -35,6 +35,20 @@ except ImportError:
# ssl is not available on Google App Engine
ssl = None
try:
import certifi
except ImportError:
# certifi is optional as long as we have ssl.create_default_context.
if ssl is None or hasattr(ssl, 'create_default_context'):
certifi = None
else:
raise
try:
xrange # py2
except NameError:
xrange = range # py3
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
@ -45,6 +59,38 @@ else:
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
SSLCertificateError = backports.ssl_match_hostname.CertificateError
if hasattr(ssl, 'SSLContext'):
if hasattr(ssl, 'create_default_context'):
# Python 2.7.9+, 3.4+
# Note that the naming of ssl.Purpose is confusing; the purpose
# of a context is to authentiate the opposite side of the connection.
_client_ssl_defaults = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH)
_server_ssl_defaults = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH)
else:
# Python 3.2-3.3
_client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
_client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED
_client_ssl_defaults.load_verify_locations(certifi.where())
_server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant wasn't added until python 3.3.
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
elif ssl:
# Python 2.6-2.7.8
_client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED,
ca_certs=certifi.where())
_server_ssl_defaults = {}
else:
# Google App Engine
_client_ssl_defaults = dict(cert_reqs=None,
ca_certs=None)
_server_ssl_defaults = {}
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
@ -60,8 +106,12 @@ _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False):
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
@ -80,7 +130,14 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
``resuse_port`` option sets ``SO_REUSEPORT`` option for every socket
in the list. If your platform doesn't support this option ValueError will
be raised.
"""
if reuse_port and not hasattr(socket, "SO_REUSEPORT"):
raise ValueError("the platform doesn't support SO_REUSEPORT")
sockets = []
if address == "":
address = None
@ -97,7 +154,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
if (platform.system() == 'Darwin' and address == 'localhost' and
if (sys.platform == 'darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
@ -115,6 +172,8 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
set_close_exec(sock.fileno())
if os.name != 'nt':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if reuse_port:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
@ -141,7 +200,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=128):
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
@ -179,12 +238,26 @@ def add_accept_handler(sock, callback, io_loop=None):
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
def accept_handler(fd, events):
while True:
# More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
# up to the number of connections that were waiting when we
# entered this method, but this information is not available
# (and rearranging this method to call accept() as many times
# as possible before running any callbacks would have adverse
# effects on load balancing in multiprocess configurations).
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
for i in xrange(_DEFAULT_BACKLOG):
try:
connection, address = sock.accept()
except socket.error as e:
@ -282,6 +355,9 @@ class ExecutorResolver(Resolver):
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None, executor=None, close_executor=True):
self.io_loop = io_loop or IOLoop.current()
@ -394,7 +470,7 @@ def ssl_options_to_context(ssl_options):
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
@ -425,11 +501,11 @@ def ssl_options_to_context(ssl_options):
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either a dictionary (as accepted by
`ssl_options_to_context`) or an `ssl.SSLContext` object.
Additional keyword arguments are passed to ``wrap_socket``
(either the `~ssl.SSLContext` method or the `ssl` module function
as appropriate).
``ssl_options`` may be either an `ssl.SSLContext` object or a
dictionary (as accepted by `ssl_options_to_context`). Additional
keyword arguments are passed to ``wrap_socket`` (either the
`~ssl.SSLContext` method or the `ssl` module function as
appropriate).
"""
context = ssl_options_to_context(ssl_options)
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):

View file

@ -68,6 +68,12 @@ instances to define isolated sets of options, such as for subcommands.
from tornado.options import options, parse_command_line
options.logging = None
parse_command_line()
.. versionchanged:: 4.3
Dashes and underscores are fully interchangeable in option names;
options can be defined, set, and read with any mix of the two.
Dashes are typical for command-line usage while config files require
underscores.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -79,7 +85,7 @@ import sys
import os
import textwrap
from tornado.escape import _unicode
from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options
from tornado import stack_context
from tornado.util import basestring_type, exec_in
@ -103,28 +109,38 @@ class OptionParser(object):
self.define("help", type=bool, help="show this help information",
callback=self._help_callback)
def _normalize_name(self, name):
return name.replace('_', '-')
def __getattr__(self, name):
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name, value):
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def __iter__(self):
return iter(self._options)
return (opt.name for opt in self._options.values())
def __getitem__(self, item):
return self._options[item].value()
def __contains__(self, name):
name = self._normalize_name(name)
return name in self._options
def __getitem__(self, name):
name = self._normalize_name(name)
return self._options[name].value()
def items(self):
"""A sequence of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(name, opt.value()) for name, opt in self._options.items()]
return [(opt.name, opt.value()) for name, opt in self._options.items()]
def groups(self):
"""The set of option-groups created by ``define``.
@ -151,7 +167,7 @@ class OptionParser(object):
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items()
(opt.name, opt.value()) for name, opt in self._options.items()
if not group or group == opt.group_name)
def as_dict(self):
@ -160,7 +176,7 @@ class OptionParser(object):
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items())
(opt.name, opt.value()) for name, opt in self._options.items())
def define(self, name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
@ -204,6 +220,13 @@ class OptionParser(object):
(name, self._options[name].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
# Can be called directly, or through top level define() fn, in which
# case, step up above that frame to look for real caller.
if (frame.f_back.f_code.co_filename == options_file and
frame.f_back.f_code.co_name == 'define'):
frame = frame.f_back
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
@ -216,11 +239,13 @@ class OptionParser(object):
group_name = group
else:
group_name = file_name
self._options[name] = _Option(name, file_name=file_name,
default=default, type=type, help=help,
metavar=metavar, multiple=multiple,
group_name=group_name,
callback=callback)
normalized = self._normalize_name(name)
option = _Option(name, file_name=file_name,
default=default, type=type, help=help,
metavar=metavar, multiple=multiple,
group_name=group_name,
callback=callback)
self._options[normalized] = option
def parse_command_line(self, args=None, final=True):
"""Parses all options given on the command line (defaults to
@ -248,8 +273,8 @@ class OptionParser(object):
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if not name in self._options:
name = self._normalize_name(name)
if name not in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]
@ -271,13 +296,18 @@ class OptionParser(object):
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
"""
config = {}
with open(path) as f:
exec_in(f.read(), config, config)
with open(path, 'rb') as f:
exec_in(native_str(f.read()), config, config)
for name in config:
if name in self._options:
self._options[name].set(config[name])
normalized = self._normalize_name(name)
if normalized in self._options:
self._options[normalized].set(config[name])
if final:
self.run_parse_callbacks()
@ -297,7 +327,8 @@ class OptionParser(object):
print("\n%s options:\n" % os.path.normpath(filename), file=file)
o.sort(key=lambda option: option.name)
for option in o:
prefix = option.name
# Always print names with dashes in a CLI context.
prefix = self._normalize_name(option.name)
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
@ -456,19 +487,17 @@ class _Option(object):
pass
raise Error('Unrecognized date/time format: %r' % value)
_TIMEDELTA_ABBREVS = [
('hours', ['h']),
('minutes', ['m', 'min']),
('seconds', ['s', 'sec']),
('milliseconds', ['ms']),
('microseconds', ['us']),
('days', ['d']),
('weeks', ['w']),
]
_TIMEDELTA_ABBREV_DICT = dict(
(abbrev, full) for full, abbrevs in _TIMEDELTA_ABBREVS
for abbrev in abbrevs)
_TIMEDELTA_ABBREV_DICT = {
'h': 'hours',
'm': 'minutes',
'min': 'minutes',
's': 'seconds',
'sec': 'seconds',
'ms': 'milliseconds',
'us': 'microseconds',
'd': 'days',
'w': 'weeks',
}
_FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'

View file

@ -1,21 +1,31 @@
"""Bridges between the `asyncio` module and Tornado IOLoop.
This is a work in progress and interfaces are subject to change.
.. versionadded:: 3.2
To test:
python3.4 -m tornado.test.runtests --ioloop=tornado.platform.asyncio.AsyncIOLoop
python3.4 -m tornado.test.runtests --ioloop=tornado.platform.asyncio.AsyncIOMainLoop
(the tests log a few warnings with AsyncIOMainLoop because they leave some
unfinished callbacks on the event loop that fail when it resumes)
This module integrates Tornado with the ``asyncio`` module introduced
in Python 3.4 (and available `as a separate download
<https://pypi.python.org/pypi/asyncio>`_ for Python 3.3). This makes
it possible to combine the two libraries on the same event loop.
Most applications should use `AsyncIOMainLoop` to run Tornado on the
default ``asyncio`` event loop. Applications that need to run event
loops on multiple threads may use `AsyncIOLoop` to create multiple
loops.
.. note::
Tornado requires the `~asyncio.BaseEventLoop.add_reader` family of methods,
so it is not compatible with the `~asyncio.ProactorEventLoop` on Windows.
Use the `~asyncio.SelectorEventLoop` instead.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
import tornado.concurrent
from tornado.gen import convert_yielded
from tornado.ioloop import IOLoop
from tornado import stack_context
from tornado.util import timedelta_to_seconds
try:
# Import the real asyncio module for py33+ first. Older versions of the
@ -29,11 +39,12 @@ except ImportError as e:
# Re-raise the original asyncio error, not the trollius one.
raise e
class BaseAsyncIOLoop(IOLoop):
def initialize(self, asyncio_loop, close_loop=False):
def initialize(self, asyncio_loop, close_loop=False, **kwargs):
super(BaseAsyncIOLoop, self).initialize(**kwargs)
self.asyncio_loop = asyncio_loop
self.close_loop = close_loop
self.asyncio_loop.call_soon(self.make_current)
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {}
# Set of fds listening for reads/writes
@ -103,8 +114,16 @@ class BaseAsyncIOLoop(IOLoop):
handler_func(fileobj, events)
def start(self):
self._setup_logging()
self.asyncio_loop.run_forever()
old_current = IOLoop.current(instance=False)
try:
self._setup_logging()
self.make_current()
self.asyncio_loop.run_forever()
finally:
if old_current is None:
IOLoop.clear_current()
else:
old_current.make_current()
def stop(self):
self.asyncio_loop.stop()
@ -131,12 +150,67 @@ class BaseAsyncIOLoop(IOLoop):
class AsyncIOMainLoop(BaseAsyncIOLoop):
def initialize(self):
"""``AsyncIOMainLoop`` creates an `.IOLoop` that corresponds to the
current ``asyncio`` event loop (i.e. the one returned by
``asyncio.get_event_loop()``). Recommended usage::
from tornado.platform.asyncio import AsyncIOMainLoop
import asyncio
AsyncIOMainLoop().install()
asyncio.get_event_loop().run_forever()
"""
def initialize(self, **kwargs):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
close_loop=False)
close_loop=False, **kwargs)
class AsyncIOLoop(BaseAsyncIOLoop):
def initialize(self):
super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(),
close_loop=True)
"""``AsyncIOLoop`` is an `.IOLoop` that runs on an ``asyncio`` event loop.
This class follows the usual Tornado semantics for creating new
``IOLoops``; these loops are not necessarily related to the
``asyncio`` default event loop. Recommended usage::
from tornado.ioloop import IOLoop
IOLoop.configure('tornado.platform.asyncio.AsyncIOLoop')
IOLoop.current().start()
Each ``AsyncIOLoop`` creates a new ``asyncio.EventLoop``; this object
can be accessed with the ``asyncio_loop`` attribute.
"""
def initialize(self, **kwargs):
loop = asyncio.new_event_loop()
try:
super(AsyncIOLoop, self).initialize(loop, close_loop=True, **kwargs)
except Exception:
# If initialize() does not succeed (taking ownership of the loop),
# we have to close it.
loop.close()
raise
def to_tornado_future(asyncio_future):
"""Convert an `asyncio.Future` to a `tornado.concurrent.Future`.
.. versionadded:: 4.1
"""
tf = tornado.concurrent.Future()
tornado.concurrent.chain_future(asyncio_future, tf)
return tf
def to_asyncio_future(tornado_future):
"""Convert a Tornado yieldable object to an `asyncio.Future`.
.. versionadded:: 4.1
.. versionchanged:: 4.3
Now accepts any yieldable object, not just
`tornado.concurrent.Future`.
"""
tornado_future = convert_yielded(tornado_future)
af = asyncio.Future()
tornado.concurrent.chain_future(tornado_future, af)
return af
if hasattr(convert_yielded, 'register'):
convert_yielded.register(asyncio.Future, to_tornado_future)

View file

@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement
import os
if os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
elif 'APPENGINE_RUNTIME' in os.environ:
if 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
def set_close_exec(fd):
pass
elif os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
@ -41,9 +42,13 @@ try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
# Silence pyflakes warning about this unused import
monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']

View file

@ -18,6 +18,9 @@ class CaresResolver(Resolver):
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()

View file

@ -54,8 +54,7 @@ class _KQueue(object):
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
if events & IOLoop.READ or not kevents:
# Always read when there is not a write
if events & IOLoop.READ:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL

View file

@ -47,7 +47,7 @@ class _Select(object):
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
self.read_fds.add(fd)
# self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)

View file

@ -12,10 +12,6 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Note: This module's docs are not currently extracted automatically,
# so changes must be made manually to twisted.rst
# TODO: refactor doc build process to use an appropriate virtualenv
"""Bridges between the Twisted reactor and Tornado IOLoop.
This module lets you run applications and libraries written for
@ -23,45 +19,6 @@ Twisted in a Tornado application. It can be used in two modes,
depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
Twisted on Tornado
------------------
`TornadoReactor` implements the Twisted reactor interface on top of
the Tornado IOLoop. To use it, simply call `install` at the beginning
of the application::
import tornado.platform.twisted
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call `IOLoop.instance().start()`
instead of `reactor.run()`.
It is also possible to create a non-global reactor by calling
`tornado.platform.twisted.TornadoReactor(io_loop)`. However, if
the `IOLoop` and reactor are to be short-lived (such as those used in
unit tests), additional cleanup may be required. Specifically, it is
recommended to call::
reactor.fireSystemEvent('shutdown')
reactor.disconnectAll()
before closing the `IOLoop`.
Tornado on Twisted
------------------
`TwistedIOLoop` implements the Tornado IOLoop interface on top of the Twisted
reactor. Recommended usage::
from tornado.platform.twisted import TwistedIOLoop
from twisted.internet import reactor
TwistedIOLoop().install()
# Set up your tornado application as usual using `IOLoop.instance`
reactor.run()
`TwistedIOLoop` always uses the global Twisted reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -70,8 +27,10 @@ import datetime
import functools
import numbers
import socket
import sys
import twisted.internet.abstract
from twisted.internet.defer import Deferred
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
@ -84,6 +43,7 @@ import twisted.names.resolve
from zope.interface import implementer
from tornado.concurrent import Future
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
@ -141,12 +101,30 @@ class TornadoDelayedCall(object):
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
Since it is intented to be used in applications where the top-level
event loop is ``io_loop.start()`` rather than ``reactor.run()``,
it is implemented a little differently than other Twisted reactors.
We override `mainLoop` instead of `doIteration` and must implement
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
`TornadoReactor` implements the Twisted reactor interface on top of
the Tornado IOLoop. To use it, simply call `install` at the beginning
of the application::
import tornado.platform.twisted
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call ``IOLoop.current().start()``
instead of ``reactor.run()``.
It is also possible to create a non-global reactor by calling
``tornado.platform.twisted.TornadoReactor(io_loop)``. However, if
the `.IOLoop` and reactor are to be short-lived (such as those used in
unit tests), additional cleanup may be required. Specifically, it is
recommended to call::
reactor.fireSystemEvent('shutdown')
reactor.disconnectAll()
before closing the `.IOLoop`.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, io_loop=None):
if not io_loop:
@ -185,7 +163,6 @@ class TornadoReactor(PosixReactorBase):
# IReactorThreads
def callFromThread(self, f, *args, **kw):
"""See `twisted.internet.interfaces.IReactorThreads.callFromThread`"""
assert callable(f), "%s is not callable" % f
with NullContext():
# This NullContext is mainly for an edge case when running
@ -231,7 +208,6 @@ class TornadoReactor(PosixReactorBase):
writer.writeConnectionLost(failure.Failure(err))
def addReader(self, reader):
"""Add a FileDescriptor for notification of data available to read."""
if reader in self._readers:
# Don't add the reader if it's already there
return
@ -251,7 +227,6 @@ class TornadoReactor(PosixReactorBase):
IOLoop.READ)
def addWriter(self, writer):
"""Add a FileDescriptor for notification of data available to write."""
if writer in self._writers:
return
fd = writer.fileno()
@ -270,7 +245,6 @@ class TornadoReactor(PosixReactorBase):
IOLoop.WRITE)
def removeReader(self, reader):
"""Remove a Selectable for notification of data available to read."""
if reader in self._readers:
fd = self._readers.pop(reader)
(_, writer) = self._fds[fd]
@ -287,7 +261,6 @@ class TornadoReactor(PosixReactorBase):
self._io_loop.remove_handler(fd)
def removeWriter(self, writer):
"""Remove a Selectable for notification of data available to write."""
if writer in self._writers:
fd = self._writers.pop(writer)
(reader, _) = self._fds[fd]
@ -328,6 +301,14 @@ class TornadoReactor(PosixReactorBase):
raise NotImplementedError("doIteration")
def mainLoop(self):
# Since this class is intended to be used in applications
# where the top-level event loop is ``io_loop.start()`` rather
# than ``reactor.run()``, it is implemented a little
# differently than other Twisted reactors. We override
# ``mainLoop`` instead of ``doIteration`` and must implement
# timed call functionality on top of `.IOLoop.add_timeout`
# rather than using the implementation in
# ``PosixReactorBase``.
self._io_loop.start()
@ -356,7 +337,20 @@ class _TestReactor(TornadoReactor):
def install(io_loop=None):
"""Install this package as the default Twisted reactor."""
"""Install this package as the default Twisted reactor.
``install()`` must be called very early in the startup process,
before most other twisted-related imports. Conversely, because it
initializes the `.IOLoop`, it cannot be called before
`.fork_processes` or multi-process `~.TCPServer.start`. These
conflicting requirements make it difficult to use `.TornadoReactor`
in multi-process mode, and an external process manager such as
``supervisord`` is recommended instead.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
@ -398,21 +392,30 @@ class _FD(object):
class TwistedIOLoop(tornado.ioloop.IOLoop):
"""IOLoop implementation that runs on Twisted.
`TwistedIOLoop` implements the Tornado IOLoop interface on top of
the Twisted reactor. Recommended usage::
from tornado.platform.twisted import TwistedIOLoop
from twisted.internet import reactor
TwistedIOLoop().install()
# Set up your tornado application as usual using `IOLoop.instance`
reactor.run()
Uses the global Twisted reactor by default. To create multiple
`TwistedIOLoops` in the same process, you must pass a unique reactor
``TwistedIOLoops`` in the same process, you must pass a unique reactor
when constructing each one.
Not compatible with `tornado.process.Subprocess.set_exit_callback`
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
"""
def initialize(self, reactor=None):
def initialize(self, reactor=None, **kwargs):
super(TwistedIOLoop, self).initialize(**kwargs)
if reactor is None:
import twisted.internet.reactor
reactor = twisted.internet.reactor
self.reactor = reactor
self.fds = {}
self.reactor.callWhenRunning(self.make_current)
def close(self, all_fds=False):
fds = self.fds
@ -466,8 +469,16 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
del self.fds[fd]
def start(self):
self._setup_logging()
self.reactor.run()
old_current = IOLoop.current(instance=False)
try:
self._setup_logging()
self.make_current()
self.reactor.run()
finally:
if old_current is None:
IOLoop.clear_current()
else:
old_current.make_current()
def stop(self):
self.reactor.crash()
@ -512,6 +523,9 @@ class TwistedResolver(Resolver):
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
@ -554,3 +568,18 @@ class TwistedResolver(Resolver):
(resolved_family, (resolved, port)),
]
raise gen.Return(result)
if hasattr(gen.convert_yielded, 'register'):
@gen.convert_yielded.register(Deferred)
def _(d):
f = Future()
def errback(failure):
try:
failure.raiseException()
# Should never happen, but just in case
raise Exception("errback called without error")
except:
f.set_exc_info(sys.exc_info())
d.addCallbacks(f.set_result, errback)
return f

View file

@ -29,6 +29,7 @@ import time
from binascii import hexlify
from tornado.concurrent import Future
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
@ -39,7 +40,7 @@ from tornado.util import errno_from_exception
try:
import multiprocessing
except ImportError:
# Multiprocessing is not availble on Google App Engine.
# Multiprocessing is not available on Google App Engine.
multiprocessing = None
try:
@ -48,6 +49,17 @@ except NameError:
long = int # py3
# Re-export this exception for convenience.
try:
CalledProcessError = subprocess.CalledProcessError
except AttributeError:
# The subprocess module exists in Google App Engine, but is empty.
# This module isn't very useful in that case, but it should
# at least be importable.
if 'APPENGINE_RUNTIME' not in os.environ:
raise
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
@ -191,6 +203,9 @@ class Subprocess(object):
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
STREAM = object()
@ -240,7 +255,7 @@ class Subprocess(object):
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHILD`` handler, which is a global setting
This method uses a ``SIGCHLD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
@ -255,14 +270,44 @@ class Subprocess(object):
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
def wait_for_exit(self, raise_error=True):
"""Returns a `.Future` which resolves when the process exits.
Usage::
ret = yield proc.wait_for_exit()
This is a coroutine-friendly alternative to `set_exit_callback`
(and a replacement for the blocking `subprocess.Popen.wait`).
By default, raises `subprocess.CalledProcessError` if the process
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
to suppress this behavior and return the exit status without raising.
.. versionadded:: 4.2
"""
future = Future()
def callback(ret):
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
future.set_exception(CalledProcessError(ret, None))
else:
future.set_result(ret)
self.set_exit_callback(callback)
return future
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHILD`` handler.
"""Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if cls._initialized:
return
@ -275,7 +320,7 @@ class Subprocess(object):
@classmethod
def uninitialize(cls):
"""Removes the ``SIGCHILD`` handler."""
"""Removes the ``SIGCHLD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)

View file

@ -0,0 +1,357 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
import collections
import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future
from tornado.locks import Event
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
pass
def _set_timeout(future, timeout):
if timeout:
def on_timeout():
future.set_exception(gen.TimeoutError())
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
future.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
class _QueueIterator(object):
def __init__(self, q):
self.q = q
def __anext__(self):
return self.q.get()
class Queue(object):
"""Coordinate producer and consumer coroutines.
If maxsize is 0 (the default) the queue size is unbounded.
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue
q = Queue(maxsize=2)
@gen.coroutine
def consumer():
while True:
item = yield q.get()
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
finally:
q.task_done()
@gen.coroutine
def producer():
for item in range(5):
yield q.put(item)
print('Put %s' % item)
@gen.coroutine
def main():
# Start consumer without waiting (since it never finishes).
IOLoop.current().spawn_callback(consumer)
yield producer() # Wait for producer to put all tasks.
yield q.join() # Wait for consumer to finish all tasks.
print('Done')
IOLoop.current().run_sync(main)
.. testoutput::
Put 0
Put 1
Doing work on 0
Put 2
Doing work on 1
Put 3
Doing work on 2
Put 4
Doing work on 3
Doing work on 4
Done
In Python 3.5, `Queue` implements the async iterator protocol, so
``consumer()`` could be rewritten as::
async def consumer():
async for item in q:
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
finally:
q.task_done()
.. versionchanged:: 4.3
Added ``async for`` support in Python 3.5.
"""
def __init__(self, maxsize=0):
if maxsize is None:
raise TypeError("maxsize can't be None")
if maxsize < 0:
raise ValueError("maxsize can't be negative")
self._maxsize = maxsize
self._init()
self._getters = collections.deque([]) # Futures.
self._putters = collections.deque([]) # Pairs of (item, Future).
self._unfinished_tasks = 0
self._finished = Event()
self._finished.set()
@property
def maxsize(self):
"""Number of items allowed in the queue."""
return self._maxsize
def qsize(self):
"""Number of items in the queue."""
return len(self._queue)
def empty(self):
return not self._queue
def full(self):
if self.maxsize == 0:
return False
else:
return self.qsize() >= self.maxsize
def put(self, item, timeout=None):
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
try:
self.put_nowait(item)
except QueueFull:
future = Future()
self._putters.append((item, future))
_set_timeout(future, timeout)
return future
else:
return gen._null_future
def put_nowait(self, item):
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise `QueueFull`.
"""
self._consume_expired()
if self._getters:
assert self.empty(), "queue non-empty, why are getters waiting?"
getter = self._getters.popleft()
self.__put_internal(item)
getter.set_result(self._get())
elif self.full():
raise QueueFull
else:
self.__put_internal(item)
def get(self, timeout=None):
"""Remove and return an item from the queue.
Returns a Future which resolves once an item is available, or raises
`tornado.gen.TimeoutError` after a timeout.
"""
future = Future()
try:
future.set_result(self.get_nowait())
except QueueEmpty:
self._getters.append(future)
_set_timeout(future, timeout)
return future
def get_nowait(self):
"""Remove and return an item from the queue without blocking.
Return an item if one is immediately available, else raise
`QueueEmpty`.
"""
self._consume_expired()
if self._putters:
assert self.full(), "queue not full, why are putters waiting?"
item, putter = self._putters.popleft()
self.__put_internal(item)
putter.set_result(None)
return self._get()
elif self.qsize():
return self._get()
else:
raise QueueEmpty
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each `.get` used to fetch a task, a
subsequent call to `.task_done` tells the queue that the processing
on the task is complete.
If a `.join` is blocking, it resumes when all items have been
processed; that is, when every `.put` is matched by a `.task_done`.
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
def join(self, timeout=None):
"""Block until all items in the queue are processed.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
return self._finished.wait(timeout)
@gen.coroutine
def __aiter__(self):
return _QueueIterator(self)
# These three are overridable in subclasses.
def _init(self):
self._queue = collections.deque()
def _get(self):
return self._queue.popleft()
def _put(self, item):
self._queue.append(item)
# End of the overridable methods.
def __put_internal(self, item):
self._unfinished_tasks += 1
self._finished.clear()
self._put(item)
def _consume_expired(self):
# Remove timed-out waiters.
while self._putters and self._putters[0][1].done():
self._putters.popleft()
while self._getters and self._getters[0].done():
self._getters.popleft()
def __repr__(self):
return '<%s at %s %s>' % (
type(self).__name__, hex(id(self)), self._format())
def __str__(self):
return '<%s %s>' % (type(self).__name__, self._format())
def _format(self):
result = 'maxsize=%r' % (self.maxsize, )
if getattr(self, '_queue', None):
result += ' queue=%r' % self._queue
if self._getters:
result += ' getters[%s]' % len(self._getters)
if self._putters:
result += ' putters[%s]' % len(self._putters)
if self._unfinished_tasks:
result += ' tasks=%s' % self._unfinished_tasks
return result
class PriorityQueue(Queue):
"""A `.Queue` that retrieves entries in priority order, lowest first.
Entries are typically tuples like ``(priority number, data)``.
.. testcode::
from tornado.queues import PriorityQueue
q = PriorityQueue()
q.put((1, 'medium-priority item'))
q.put((0, 'high-priority item'))
q.put((10, 'low-priority item'))
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
(0, 'high-priority item')
(1, 'medium-priority item')
(10, 'low-priority item')
"""
def _init(self):
self._queue = []
def _put(self, item):
heapq.heappush(self._queue, item)
def _get(self):
return heapq.heappop(self._queue)
class LifoQueue(Queue):
"""A `.Queue` that retrieves the most recently put items first.
.. testcode::
from tornado.queues import LifoQueue
q = LifoQueue()
q.put(3)
q.put(2)
q.put(1)
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
1
2
3
"""
def _init(self):
self._queue = []
def _put(self, item):
self._queue.append(item)
def _get(self):
return self._queue.pop()

View file

@ -1,13 +1,13 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import is_future
from tornado.escape import utf8, _unicode
from tornado import gen
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
@ -19,11 +19,8 @@ import functools
import re
import socket
import sys
from io import BytesIO
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urlparse # py2
@ -53,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
It does not currently implement all applicable parts of the HTTP
specification, but it does enough to work with major web service APIs.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
@ -63,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None, max_header_size=None):
resolver=None, defaults=None, max_header_size=None,
max_body_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
force_instance=True may be used to suppress this behavior.
``force_instance=True`` may be used to suppress this behavior.
max_clients is the number of concurrent requests that can be
in progress. Note that this arguments are only used when the
client is first created, and will be ignored when an existing
client is reused.
Note that because of this implicit reuse, unless ``force_instance``
is used, only the first call to the constructor actually uses
its arguments. It is recommended to use the ``configure`` method
instead of the constructor to ensure that arguments take effect.
hostname_mapping is a dictionary mapping hostnames to IP addresses.
``max_clients`` is the number of concurrent requests that can be
in progress; when this limit is reached additional requests will be
queued. Note that time spent waiting in this queue still counts
against the ``request_timeout``.
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like /etc/hosts is not possible or desirable (e.g. in
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
unittests).
max_buffer_size is the number of bytes that can be read by IOStream. It
defaults to 100mb.
``max_buffer_size`` (default 100MB) is the number of bytes
that can be read into memory at once. ``max_body_size``
(defaults to ``max_buffer_size``) is the largest response body
that the client will accept. Without a
``streaming_callback``, the smaller of these two limits
applies; with a ``streaming_callback`` only ``max_body_size``
does.
.. versionchanged:: 4.2
Added the ``max_body_size`` argument.
"""
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
defaults=defaults)
@ -91,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
@ -138,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _connection_class(self):
return _HTTPConnection
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size)
self._connection_class()(
self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size, self.max_body_size)
def _release_fetch(self, key):
del self.active[key]
@ -169,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, tcp_client,
max_header_size):
max_header_size, max_body_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
@ -179,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.max_body_size = max_body_size
self.code = None
self.headers = None
self.chunks = []
@ -196,12 +210,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
@ -222,16 +232,29 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stack_context.wrap(self._on_timeout))
self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
callback=self._on_connect)
def _get_ssl_options(self, scheme):
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
if (self.request.validate_cert and
self.request.ca_certs is None and
self.request.client_cert is None and
self.request.client_key is None):
return _client_ssl_defaults
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
else:
elif not hasattr(ssl, 'create_default_context'):
# When create_default_context is present,
# we can omit the "ca_certs" parameter entirely,
# which avoids the dependency on "certifi" for py34.
ssl_options["ca_certs"] = _default_ca_certs()
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
@ -277,7 +300,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self._on_close)
self.stream.set_close_callback(self.on_connection_close)
self._remove_timeout()
if self.final_callback is None:
return
@ -316,18 +339,18 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
if (self.request.body is None and
self.request.body_producer is None):
raise AssertionError(
'Body must not be empty for "%s" request'
% self.request.method)
else:
if (self.request.body is not None or
self.request.body_producer is not None):
raise AssertionError(
'Body must be empty for "%s" request'
% self.request.method)
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
body_expected = self.request.method in ("POST", "PATCH", "PUT")
body_present = (self.request.body is not None or
self.request.body_producer is not None)
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
@ -342,29 +365,35 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
self.stream.set_nodelay(True)
self.connection = HTTP1Connection(
self.stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
decompress=self.request.decompress_response),
self._sockaddr)
self.connection = self._create_connection(stream)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
req_path, '')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
self._read_response()
else:
self._write_body(True)
def _create_connection(self, stream):
stream.set_nodelay(True)
connection = HTTP1Connection(
stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
decompress=self.request.decompress_response),
self._sockaddr)
return connection
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):
if fut is not None:
fut = gen.convert_yielded(fut)
def on_body_written(fut):
fut.result()
self.connection.finish()
@ -372,7 +401,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self._read_response()
self.io_loop.add_future(fut, on_body_written)
return
self.connection.finish()
self.connection.finish()
if start_read:
self._read_response()
@ -400,7 +429,10 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
value = HTTPError(599, "Stream closed")
if value.real_error is None:
value = HTTPError(599, "Stream closed")
else:
value = value.real_error
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
))
@ -418,34 +450,26 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def _on_close(self):
def on_connection_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
raise self.stream.error
raise HTTPError(599, message)
try:
raise HTTPError(599, message)
except HTTPError:
self._handle_exception(*sys.exc_info())
def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
return
self.headers = headers
self.code = first_line.code
self.reason = first_line.reason
self.headers = headers
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', self.headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise ValueError("Multiple unequal Content-Lengths: %r" %
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]
content_length = int(self.headers["Content-Length"])
else:
content_length = None
if self._should_follow_redirect():
return
if self.request.header_callback is not None:
# Reassemble the start line.
@ -454,22 +478,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
def _should_follow_redirect(self):
return (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307))
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
if (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307)):
if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urlparse.urljoin(self.request.url,
@ -516,6 +535,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.stream.close()
def data_received(self, chunk):
if self._should_follow_redirect():
# We're going to follow a redirect so just discard the body.
return
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:

View file

@ -41,13 +41,13 @@ Example usage::
sys.exit(1)
with StackContext(die_on_error):
# Any exception thrown here *or in callback and its desendents*
# Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouln't have to work with `StackContext` directly.
Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a

View file

@ -111,6 +111,7 @@ class _Connector(object):
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
self.clear_timeout()
@ -135,6 +136,9 @@ class _Connector(object):
class TCPClient(object):
"""A non-blocking TCP connection factory.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, resolver=None, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
@ -163,7 +167,7 @@ class TCPClient(object):
functools.partial(self._create_stream, max_buffer_size))
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on sbusequent connections to
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
stream = yield stream.start_tls(False, ssl_options=ssl_options,

View file

@ -41,14 +41,15 @@ class TCPServer(object):
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
To make this server serve SSL traffic, send the ssl_options dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including "certfile" and "keyfile"::
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
TCPServer(ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
TCPServer(ssl_options=ssl_ctx)
`TCPServer` initialization follows one of three patterns:
@ -56,14 +57,14 @@ class TCPServer(object):
server = TCPServer()
server.listen(8888)
IOLoop.instance().start()
IOLoop.current().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
@ -75,7 +76,7 @@ class TCPServer(object):
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.instance().start()
IOLoop.current().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
@ -95,7 +96,7 @@ class TCPServer(object):
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = None
self.read_chunk_size = read_chunk_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
@ -212,7 +213,20 @@ class TCPServer(object):
sock.close()
def handle_stream(self, stream, address):
"""Override to handle a new `.IOStream` from an incoming connection."""
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
asynchronously will be logged. Accepting of incoming connections
will not be blocked by this coroutine.
If this `TCPServer` is configured for SSL, ``handle_stream``
may be called before the SSL handshake has completed. Use
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
certificate or use NPN/ALPN.
.. versionchanged:: 4.2
Added the option for this method to be a coroutine.
"""
raise NotImplementedError()
def _handle_connection(self, connection, address):
@ -252,6 +266,8 @@ class TCPServer(object):
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
self.handle_stream(stream, address)
future = self.handle_stream(stream, address)
if future is not None:
self.io_loop.add_future(future, lambda f: f.result())
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View file

@ -186,6 +186,11 @@ with ``{# ... #}``.
``{% while *condition* %}... {% end %}``
Same as the python ``while`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% whitespace *mode* %}``
Sets the whitespace mode for the remainder of the current file
(or until the next ``{% whitespace %}`` directive). See
`filter_whitespace` for available options. New in Tornado 4.3.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -199,7 +204,7 @@ import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import bytes_type, ObjectDict, exec_in, unicode_type
from tornado.util import ObjectDict, exec_in, unicode_type
try:
from cStringIO import StringIO # py2
@ -210,6 +215,31 @@ _DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
def filter_whitespace(mode, text):
"""Transform whitespace in ``text`` according to ``mode``.
Available modes are:
* ``all``: Return all whitespace unmodified.
* ``single``: Collapse consecutive whitespace with a single whitespace
character, preserving newlines.
* ``oneline``: Collapse all runs of whitespace into a single space
character, removing all newlines in the process.
.. versionadded:: 4.3
"""
if mode == 'all':
return text
elif mode == 'single':
text = re.sub(r"([\t ]+)", " ", text)
text = re.sub(r"(\s*\n\s*)", "\n", text)
return text
elif mode == 'oneline':
return re.sub(r"(\s+)", " ", text)
else:
raise Exception("invalid whitespace mode %s" % mode)
class Template(object):
"""A compiled template.
@ -220,21 +250,58 @@ class Template(object):
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
def __init__(self, template_string, name="<string>", loader=None,
compress_whitespace=None, autoescape=_UNSET):
self.name = name
if compress_whitespace is None:
compress_whitespace = name.endswith(".html") or \
name.endswith(".js")
compress_whitespace=_UNSET, autoescape=_UNSET,
whitespace=None):
"""Construct a Template.
:arg str template_string: the contents of the template file.
:arg str name: the filename from which the template was loaded
(used for error message).
:arg tornado.template.BaseLoader loader: the `~tornado.template.BaseLoader` responsible for this template,
used to resolve ``{% include %}`` and ``{% extend %}``
directives.
:arg bool compress_whitespace: Deprecated since Tornado 4.3.
Equivalent to ``whitespace="single"`` if true and
``whitespace="all"`` if false.
:arg str autoescape: The name of a function in the template
namespace, or ``None`` to disable escaping by default.
:arg str whitespace: A string specifying treatment of whitespace;
see `filter_whitespace` for options.
.. versionchanged:: 4.3
Added ``whitespace`` parameter; deprecated ``compress_whitespace``.
"""
self.name = escape.native_str(name)
if compress_whitespace is not _UNSET:
# Convert deprecated compress_whitespace (bool) to whitespace (str).
if whitespace is not None:
raise Exception("cannot set both whitespace and compress_whitespace")
whitespace = "single" if compress_whitespace else "all"
if whitespace is None:
if loader and loader.whitespace:
whitespace = loader.whitespace
else:
# Whitespace defaults by filename.
if name.endswith(".html") or name.endswith(".js"):
whitespace = "single"
else:
whitespace = "all"
# Validate the whitespace setting.
filter_whitespace(whitespace, '')
if autoescape is not _UNSET:
self.autoescape = autoescape
elif loader:
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
reader = _TemplateReader(name, escape.native_str(template_string))
reader = _TemplateReader(name, escape.native_str(template_string),
whitespace)
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader, compress_whitespace)
self.code = self._generate_python(loader)
self.loader = loader
try:
# Under python2.5, the fake filename used here must match
@ -261,7 +328,7 @@ class Template(object):
"linkify": escape.linkify,
"datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes_type),
"_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
"__name__": self.name.replace('.', '_'),
@ -277,7 +344,7 @@ class Template(object):
linecache.clearcache()
return execute()
def _generate_python(self, loader, compress_whitespace):
def _generate_python(self, loader):
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
@ -286,8 +353,8 @@ class Template(object):
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template,
compress_whitespace)
writer = _CodeWriter(buffer, named_blocks, loader,
ancestors[0].template)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
@ -312,12 +379,26 @@ class BaseLoader(object):
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None):
"""``autoescape`` must be either None or a string naming a function
in the template namespace, such as "xhtml_escape".
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
whitespace=None):
"""Construct a template loader.
:arg str autoescape: The name of a function in the template
namespace, such as "xhtml_escape", or ``None`` to disable
autoescaping by default.
:arg dict namespace: A dictionary to be added to the default template
namespace, or ``None``.
:arg str whitespace: A string specifying default behavior for
whitespace in templates; see `filter_whitespace` for options.
Default is "single" for files ending in ".html" and ".js" and
"all" for other files.
.. versionchanged:: 4.3
Added ``whitespace`` parameter.
"""
self.autoescape = autoescape
self.namespace = namespace or {}
self.whitespace = whitespace
self.templates = {}
# self.lock protects self.templates. It's a reentrant lock
# because templates may load other templates via `include` or
@ -558,37 +639,49 @@ class _Module(_Expression):
class _Text(_Node):
def __init__(self, value, line):
def __init__(self, value, line, whitespace):
self.value = value
self.line = line
self.whitespace = whitespace
def generate(self, writer):
value = self.value
# Compress lots of white space to a single character. If the whitespace
# breaks a line, have it continue to break a line, but just with a
# single \n character
if writer.compress_whitespace and "<pre>" not in value:
value = re.sub(r"([\t ]+)", " ", value)
value = re.sub(r"(\s*\n\s*)", "\n", value)
# Compress whitespace if requested, with a crude heuristic to avoid
# altering preformatted whitespace.
if "<pre>" not in value:
value = filter_whitespace(self.whitespace, value)
if value:
writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
class ParseError(Exception):
"""Raised for template syntax errors."""
pass
"""Raised for template syntax errors.
``ParseError`` instances have ``filename`` and ``lineno`` attributes
indicating the position of the error.
.. versionchanged:: 4.3
Added ``filename`` and ``lineno`` attributes.
"""
def __init__(self, message, filename, lineno):
self.message = message
# The names "filename" and "lineno" are chosen for consistency
# with python SyntaxError.
self.filename = filename
self.lineno = lineno
def __str__(self):
return '%s at %s:%d' % (self.message, self.filename, self.lineno)
class _CodeWriter(object):
def __init__(self, file, named_blocks, loader, current_template,
compress_whitespace):
def __init__(self, file, named_blocks, loader, current_template):
self.file = file
self.named_blocks = named_blocks
self.loader = loader
self.current_template = current_template
self.compress_whitespace = compress_whitespace
self.apply_counter = 0
self.include_stack = []
self._indent = 0
@ -633,9 +726,10 @@ class _CodeWriter(object):
class _TemplateReader(object):
def __init__(self, name, text):
def __init__(self, name, text, whitespace):
self.name = name
self.text = text
self.whitespace = whitespace
self.line = 1
self.pos = 0
@ -687,6 +781,9 @@ class _TemplateReader(object):
def __str__(self):
return self.text[self.pos:]
def raise_parse_error(self, msg):
raise ParseError(msg, self.name, self.line)
def _format_code(code):
lines = code.splitlines()
@ -704,9 +801,10 @@ def _parse(reader, template, in_block=None, in_loop=None):
if curly == -1 or curly + 1 == reader.remaining():
# EOF
if in_block:
raise ParseError("Missing {%% end %%} block for %s" %
in_block)
body.chunks.append(_Text(reader.consume(), reader.line))
reader.raise_parse_error(
"Missing {%% end %%} block for %s" % in_block)
body.chunks.append(_Text(reader.consume(), reader.line,
reader.whitespace))
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
@ -725,7 +823,8 @@ def _parse(reader, template, in_block=None, in_loop=None):
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
body.chunks.append(_Text(cons, reader.line))
body.chunks.append(_Text(cons, reader.line,
reader.whitespace))
start_brace = reader.consume(2)
line = reader.line
@ -736,14 +835,15 @@ def _parse(reader, template, in_block=None, in_loop=None):
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
body.chunks.append(_Text(start_brace, line))
body.chunks.append(_Text(start_brace, line,
reader.whitespace))
continue
# Comment
if start_brace == "{#":
end = reader.find("#}")
if end == -1:
raise ParseError("Missing end expression #} on line %d" % line)
reader.raise_parse_error("Missing end comment #}")
contents = reader.consume(end).strip()
reader.consume(2)
continue
@ -752,11 +852,11 @@ def _parse(reader, template, in_block=None, in_loop=None):
if start_brace == "{{":
end = reader.find("}}")
if end == -1:
raise ParseError("Missing end expression }} on line %d" % line)
reader.raise_parse_error("Missing end expression }}")
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty expression on line %d" % line)
reader.raise_parse_error("Empty expression")
body.chunks.append(_Expression(contents, line))
continue
@ -764,11 +864,11 @@ def _parse(reader, template, in_block=None, in_loop=None):
assert start_brace == "{%", start_brace
end = reader.find("%}")
if end == -1:
raise ParseError("Missing end block %%} on line %d" % line)
reader.raise_parse_error("Missing end block %}")
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty block tag ({%% %%}) on line %d" % line)
reader.raise_parse_error("Empty block tag ({% %})")
operator, space, suffix = contents.partition(" ")
suffix = suffix.strip()
@ -783,40 +883,43 @@ def _parse(reader, template, in_block=None, in_loop=None):
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
raise ParseError("%s outside %s block" %
(operator, allowed_parents))
reader.raise_parse_error("%s outside %s block" %
(operator, allowed_parents))
if in_block not in allowed_parents:
raise ParseError("%s block cannot be attached to %s block" % (operator, in_block))
reader.raise_parse_error(
"%s block cannot be attached to %s block" %
(operator, in_block))
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
# End tag
elif operator == "end":
if not in_block:
raise ParseError("Extra {%% end %%} block on line %d" % line)
reader.raise_parse_error("Extra {% end %} block")
return body
elif operator in ("extends", "include", "set", "import", "from",
"comment", "autoescape", "raw", "module"):
"comment", "autoescape", "whitespace", "raw",
"module"):
if operator == "comment":
continue
if operator == "extends":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("extends missing file path on line %d" % line)
reader.raise_parse_error("extends missing file path")
block = _ExtendsBlock(suffix)
elif operator in ("import", "from"):
if not suffix:
raise ParseError("import missing statement on line %d" % line)
reader.raise_parse_error("import missing statement")
block = _Statement(contents, line)
elif operator == "include":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("include missing file path on line %d" % line)
reader.raise_parse_error("include missing file path")
block = _IncludeBlock(suffix, reader, line)
elif operator == "set":
if not suffix:
raise ParseError("set missing statement on line %d" % line)
reader.raise_parse_error("set missing statement")
block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
@ -824,6 +927,12 @@ def _parse(reader, template, in_block=None, in_loop=None):
fn = None
template.autoescape = fn
continue
elif operator == "whitespace":
mode = suffix.strip()
# Validate the selected mode
filter_whitespace(mode, '')
reader.whitespace = mode
continue
elif operator == "raw":
block = _Expression(suffix, line, raw=True)
elif operator == "module":
@ -844,11 +953,11 @@ def _parse(reader, template, in_block=None, in_loop=None):
if operator == "apply":
if not suffix:
raise ParseError("apply missing method name on line %d" % line)
reader.raise_parse_error("apply missing method name")
block = _ApplyBlock(suffix, line, block_body)
elif operator == "block":
if not suffix:
raise ParseError("block missing name on line %d" % line)
reader.raise_parse_error("block missing name")
block = _NamedBlock(suffix, block_body, template, line)
else:
block = _ControlBlock(contents, line, block_body)
@ -857,9 +966,10 @@ def _parse(reader, template, in_block=None, in_loop=None):
elif operator in ("break", "continue"):
if not in_loop:
raise ParseError("%s outside %s block" % (operator, set(["for", "while"])))
reader.raise_parse_error("%s outside %s block" %
(operator, set(["for", "while"])))
body.chunks.append(_Statement(contents, line))
continue
else:
raise ParseError("unknown operator: %r" % operator)
reader.raise_parse_error("unknown operator: %r" % operator)

View file

@ -1,4 +0,0 @@
Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH apprioriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect.

View file

@ -0,0 +1,113 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
try:
from tornado.platform.asyncio import asyncio
except ImportError:
asyncio = None
else:
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future
# This is used in dynamically-evaluated code, so silence pyflakes.
to_asyncio_future
@unittest.skipIf(asyncio is None, "asyncio module not present")
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def test_asyncio_callback(self):
# Basic test that the asyncio loop is set up correctly.
asyncio.get_event_loop().call_soon(self.stop)
self.wait()
@gen_test
def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in asyncio.async.
x = yield asyncio.async(
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
self.assertEqual(x, 42)
@skipBefore33
@gen_test
def test_asyncio_yield_from(self):
# Test that we can use asyncio coroutines with 'yield from'
# instead of asyncio.async(). This requires python 3.3 syntax.
namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
@skipBefore35
def test_asyncio_adapter(self):
# This test demonstrates that when using the asyncio coroutine
# runner (i.e. run_until_complete), the to_asyncio_future
# adapter is needed. No adapter is needed in the other direction,
# as demonstrated by other tests in the package.
@gen.coroutine
def tornado_coroutine():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(42)
native_coroutine_without_adapter = exec_test(globals(), locals(), """
async def native_coroutine_without_adapter():
return await tornado_coroutine()
""")["native_coroutine_without_adapter"]
native_coroutine_with_adapter = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter():
return await to_asyncio_future(tornado_coroutine())
""")["native_coroutine_with_adapter"]
# Use the adapter, but two degrees from the tornado coroutine.
native_coroutine_with_adapter2 = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter2():
return await to_asyncio_future(native_coroutine_without_adapter())
""")["native_coroutine_with_adapter2"]
# Tornado supports native coroutines both with and without adapters
self.assertEqual(
self.io_loop.run_sync(native_coroutine_without_adapter),
42)
self.assertEqual(
self.io_loop.run_sync(native_coroutine_with_adapter),
42)
self.assertEqual(
self.io_loop.run_sync(native_coroutine_with_adapter2),
42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures.
with self.assertRaises(RuntimeError):
asyncio.get_event_loop().run_until_complete(
native_coroutine_without_adapter())
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter()),
42)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter2()),
42)

View file

@ -5,10 +5,11 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError, GoogleOAuth2Mixin, FacebookGraphMixin
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.util import u
@ -125,6 +126,38 @@ class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
assert res.done()
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token')
self._FACEBOOK_BASE_URL = test.get_url('/facebook/server')
@gen.coroutine
def get(self):
if self.get_argument("code", None):
user = yield self.get_authenticated_user(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"))
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
class FacebookServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('access_token=asdf')
class FacebookServerMeHandler(RequestHandler):
def get(self):
self.write('{}')
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
@ -238,28 +271,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler):
self.write(dict(screen_name='foo', name='Foo'))
class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument("openid.mode", None):
self.get_authenticated_user(self.on_user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def get_auth_http_client(self):
return self.settings['http_client']
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
@ -281,25 +292,30 @@ class AuthTest(AsyncHTTPTestCase):
dict(version='1.0a')),
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
('/twitter/client/login_gen_engine', TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
('/facebook/server/access_token', FacebookServerAccessTokenHandler),
('/facebook/server/me', FacebookServerMeHandler),
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
(r'/twitter/api/account/verify_credentials\.json', TwitterServerVerifyCredentialsHandler),
],
http_client=self.http_client,
twitter_consumer_key='test_twitter_consumer_key',
twitter_consumer_secret='test_twitter_consumer_secret')
twitter_consumer_secret='test_twitter_consumer_secret',
facebook_api_key='test_facebook_api_key',
facebook_secret='test_facebook_secret')
def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
@ -380,6 +396,13 @@ class AuthTest(AsyncHTTPTestCase):
self.assertEqual(response.code, 302)
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
def test_facebook_login(self):
response = self.fetch('/facebook/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
self.assertEqual(response.code, 200)
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
@ -437,15 +460,86 @@ class AuthTest(AsyncHTTPTestCase):
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
def test_google_redirect(self):
# same as test_openid_redirect
response = self.fetch('/google/client/openid_login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_google_get_user(self):
response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
self.test = test
self._OAUTH_REDIRECT_URI = test.get_url('/client/login')
self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token')
@gen.coroutine
def get(self):
code = self.get_argument('code', None)
if code is not None:
# retrieve authenticate google user
access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI,
code)
user = yield self.oauth2_request(
self.test.get_url("/google/oauth2/userinfo"),
access_token=access["access_token"])
# return the user and access token as json
user["access_token"] = access["access_token"]
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self._OAUTH_REDIRECT_URI,
client_id=self.settings['google_oauth']['key'],
client_secret=self.settings['google_oauth']['secret'],
scope=['profile', 'email'],
response_type='code',
extra_params={'prompt': 'select_account'})
class GoogleOAuth2AuthorizeHandler(RequestHandler):
def get(self):
# issue a fake auth code and redirect to redirect_uri
code = 'fake-authorization-code'
self.redirect(url_concat(self.get_argument('redirect_uri'),
dict(code=code)))
class GoogleOAuth2TokenHandler(RequestHandler):
def post(self):
assert self.get_argument('code') == 'fake-authorization-code'
# issue a fake token
self.finish({
'access_token': 'fake-access-token',
'expires_in': 'never-expires'
})
class GoogleOAuth2UserinfoHandler(RequestHandler):
def get(self):
assert self.get_argument('access_token') == 'fake-access-token'
# return a fake user
self.finish({
'name': 'Foo',
'email': 'foo@example.com'
})
class GoogleOAuth2Test(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/client/login', GoogleLoginHandler, dict(test=self)),
# simulated google authorization server endpoints
('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler),
('/google/oauth2/token', GoogleOAuth2TokenHandler),
('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler),
],
google_oauth={
"key": 'fake_google_client_id',
"secret": 'fake_google_client_secret'
})
def test_google_login(self):
response = self.fetch('/client/login')
self.assertDictEqual({
u('name'): u('Foo'),
u('email'): u('foo@example.com'),
u('access_token'): u('fake-access-token'),
}, json_decode(response.body))

View file

@ -21,13 +21,14 @@ import socket
import sys
import traceback
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
from tornado.test.util import unittest
try:
@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = GeneratorCapClient
@unittest.skipIf(futures is None, "concurrent.futures module not present")
class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_no_args(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_io_loop(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_executor(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(executor='_Object__executor')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_both(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop', executor='_Object__executor')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)

View file

@ -10,6 +10,7 @@ from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
try:
import pycurl
except ImportError:
@ -120,3 +121,4 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def test_fail_custom_reason(self):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")

View file

@ -4,8 +4,8 @@
from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode
from tornado.util import u, unicode_type, bytes_type
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode, squeeze, recursive_unicode
from tornado.util import u, unicode_type
from tornado.test.util import unittest
linkify_tests = [
@ -154,6 +154,19 @@ class EscapeTestCase(unittest.TestCase):
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_xhtml_unescape_numeric(self):
tests = [
('foo&#32;bar', 'foo bar'),
('foo&#x20;bar', 'foo bar'),
('foo&#X20;bar', 'foo bar'),
('foo&#xabc;bar', u('foo\u0abcbar')),
('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
('foo&#;bar', 'foo&#;bar'), # invalid encoding
('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
]
for escaped, unescaped in tests:
self.assertEqual(unescaped, xhtml_unescape(escaped))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
@ -212,6 +225,21 @@ class EscapeTestCase(unittest.TestCase):
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
if bytes_type is str:
if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
def test_recursive_unicode(self):
tests = {
'dict': {b"foo": b"bar"},
'list': [b"foo", b"bar"],
'tuple': (b"foo", b"bar"),
'bytes': b"foo"
}
self.assertEqual(recursive_unicode(tests['dict']), {u("foo"): u("bar")})
self.assertEqual(recursive_unicode(tests['list']), [u("foo"), u("bar")])
self.assertEqual(recursive_unicode(tests['tuple']), (u("foo"), u("bar")))
self.assertEqual(recursive_unicode(tests['bytes']), u("foo"))

View file

@ -6,7 +6,6 @@ import functools
import sys
import textwrap
import time
import platform
import weakref
from tornado.concurrent import return_future, Future
@ -16,7 +15,7 @@ from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.test.util import unittest, skipOnTravis, skipBefore33, skipBefore35, skipNotCPython, exec_test
from tornado.web import Application, RequestHandler, asynchronous, HTTPError
from tornado import gen
@ -26,10 +25,6 @@ try:
except ImportError:
futures = None
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
'Not CPython implementation')
class GenEngineTest(AsyncTestCase):
def setUp(self):
@ -62,6 +57,11 @@ class GenEngineTest(AsyncTestCase):
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
@gen.coroutine
def async_exception(self, e):
yield gen.moment
raise e
def test_no_yield(self):
@gen.engine
def f():
@ -385,11 +385,56 @@ class GenEngineTest(AsyncTestCase):
results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2])
@gen_test
def test_multi_future_duplicate(self):
f = self.async_future(2)
results = yield [self.async_future(1), f, self.async_future(3), f]
self.assertEqual(results, [1, 2, 3, 2])
@gen_test
def test_multi_dict_future(self):
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
self.assertEqual(results, dict(foo=1, bar=2))
@gen_test
def test_multi_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))])
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_future(2)])
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))],
quiet_exceptions=RuntimeError)
@gen_test
def test_multi_future_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
yield [self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))]
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
yield [self.async_exception(RuntimeError("error 1")),
self.async_future(2)]
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.multi_future(
[self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))],
quiet_exceptions=RuntimeError)
def test_arguments(self):
@gen.engine
def f():
@ -643,19 +688,13 @@ class GenCoroutineTest(AsyncTestCase):
@skipBefore33
@gen_test
def test_async_return(self):
# It is a compile-time error to return a value in a generator
# before Python 3.3, so we must test this with exec.
# Flatten the real global and local namespace into our fake globals:
# it's all global from the perspective of f().
global_namespace = dict(globals(), **locals())
local_namespace = {}
exec(textwrap.dedent("""
namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
return 42
"""), global_namespace, local_namespace)
result = yield local_namespace['f']()
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
self.finished = True
@ -665,19 +704,69 @@ class GenCoroutineTest(AsyncTestCase):
# A yield statement exists but is not executed, which means
# this function "returns" via an exception. This exception
# doesn't happen before the exception handling is set up.
global_namespace = dict(globals(), **locals())
local_namespace = {}
exec(textwrap.dedent("""
namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
if True:
return 42
yield gen.Task(self.io_loop.add_callback)
"""), global_namespace, local_namespace)
result = yield local_namespace['f']()
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
self.finished = True
@skipBefore35
@gen_test
def test_async_await(self):
# This test verifies that an async function can await a
# yield-based gen.coroutine, and that a gen.coroutine
# (the test method itself) can yield an async function.
namespace = exec_test(globals(), locals(), """
async def f():
await gen.Task(self.io_loop.add_callback)
return 42
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
self.finished = True
@skipBefore35
@gen_test
def test_async_await_mixed_multi_native_future(self):
namespace = exec_test(globals(), locals(), """
async def f1():
await gen.Task(self.io_loop.add_callback)
return 42
""")
@gen.coroutine
def f2():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(43)
results = yield [namespace['f1'](), f2()]
self.assertEqual(results, [42, 43])
self.finished = True
@skipBefore35
@gen_test
def test_async_await_mixed_multi_native_yieldpoint(self):
namespace = exec_test(globals(), locals(), """
async def f1():
await gen.Task(self.io_loop.add_callback)
return 42
""")
@gen.coroutine
def f2():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(43)
f2(callback=(yield gen.Callback('cb')))
results = yield [namespace['f1'](), gen.Wait('cb')]
self.assertEqual(results, [42, 43])
self.finished = True
@gen_test
def test_sync_return_no_value(self):
@gen.coroutine
@ -816,6 +905,7 @@ class GenCoroutineTest(AsyncTestCase):
@gen_test
def test_moment(self):
calls = []
@gen.coroutine
def f(name, yieldable):
for i in range(5):
@ -838,6 +928,35 @@ class GenCoroutineTest(AsyncTestCase):
yield [f('a', gen.moment), f('b', immediate)]
self.assertEqual(''.join(calls), 'abbbbbaaaa')
@gen_test
def test_sleep(self):
yield gen.sleep(0.01)
self.finished = True
@skipBefore33
@gen_test
def test_py3_leak_exception_context(self):
class LeakedException(Exception):
pass
@gen.coroutine
def inner(iteration):
raise LeakedException(iteration)
try:
yield inner(1)
except LeakedException as e:
self.assertEqual(str(e), "1")
self.assertIsNone(e.__context__)
try:
yield inner(2)
except LeakedException as e:
self.assertEqual(str(e), "2")
self.assertIsNone(e.__context__)
self.finished = True
class GenSequenceHandler(RequestHandler):
@asynchronous
@ -936,6 +1055,7 @@ class GenYieldExceptionHandler(RequestHandler):
self.finish('ok')
# "Undecorated" here refers to the absence of @asynchronous.
class UndecoratedCoroutinesHandler(RequestHandler):
@gen.coroutine
def prepare(self):
@ -962,6 +1082,15 @@ class AsyncPrepareErrorHandler(RequestHandler):
self.finish('ok')
class NativeCoroutineHandler(RequestHandler):
if sys.version_info > (3, 5):
exec(textwrap.dedent("""
async def get(self):
await gen.Task(IOLoop.current().add_callback)
self.write("ok")
"""))
class GenWebTest(AsyncHTTPTestCase):
def get_app(self):
return Application([
@ -975,6 +1104,7 @@ class GenWebTest(AsyncHTTPTestCase):
('/yield_exception', GenYieldExceptionHandler),
('/undecorated_coroutine', UndecoratedCoroutinesHandler),
('/async_prepare_error', AsyncPrepareErrorHandler),
('/native_coroutine', NativeCoroutineHandler),
])
def test_sequence_handler(self):
@ -1017,6 +1147,12 @@ class GenWebTest(AsyncHTTPTestCase):
response = self.fetch('/async_prepare_error')
self.assertEqual(response.code, 403)
@skipBefore35
def test_native_coroutine_handler(self):
response = self.fetch('/native_coroutine')
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b'ok')
class WithTimeoutTest(AsyncTestCase):
@gen_test
@ -1031,7 +1167,7 @@ class WithTimeoutTest(AsyncTestCase):
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@gen_test
@ -1039,16 +1175,17 @@ class WithTimeoutTest(AsyncTestCase):
future = Future()
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
lambda: future.set_exception(ZeroDivisionError))
lambda: future.set_exception(ZeroDivisionError()))
with self.assertRaises(ZeroDivisionError):
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
yield gen.with_timeout(datetime.timedelta(seconds=3600),
future, io_loop=self.io_loop)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@ -1067,5 +1204,156 @@ class WithTimeoutTest(AsyncTestCase):
executor.submit(lambda: None))
class WaitIteratorTest(AsyncTestCase):
@gen_test
def test_empty_iterator(self):
g = gen.WaitIterator()
self.assertTrue(g.done(), 'empty generator iterated')
with self.assertRaises(ValueError):
g = gen.WaitIterator(False, bar=False)
self.assertEqual(g.current_index, None, "bad nil current index")
self.assertEqual(g.current_future, None, "bad nil current future")
@gen_test
def test_already_done(self):
f1 = Future()
f2 = Future()
f3 = Future()
f1.set_result(24)
f2.set_result(42)
f3.set_result(84)
g = gen.WaitIterator(f1, f2, f3)
i = 0
while not g.done():
r = yield g.next()
# Order is not guaranteed, but the current implementation
# preserves ordering of already-done Futures.
if i == 0:
self.assertEqual(g.current_index, 0)
self.assertIs(g.current_future, f1)
self.assertEqual(r, 24)
elif i == 1:
self.assertEqual(g.current_index, 1)
self.assertIs(g.current_future, f2)
self.assertEqual(r, 42)
elif i == 2:
self.assertEqual(g.current_index, 2)
self.assertIs(g.current_future, f3)
self.assertEqual(r, 84)
i += 1
self.assertEqual(g.current_index, None, "bad nil current index")
self.assertEqual(g.current_future, None, "bad nil current future")
dg = gen.WaitIterator(f1=f1, f2=f2)
while not dg.done():
dr = yield dg.next()
if dg.current_index == "f1":
self.assertTrue(dg.current_future == f1 and dr == 24,
"WaitIterator dict status incorrect")
elif dg.current_index == "f2":
self.assertTrue(dg.current_future == f2 and dr == 42,
"WaitIterator dict status incorrect")
else:
self.fail("got bad WaitIterator index {}".format(
dg.current_index))
i += 1
self.assertEqual(dg.current_index, None, "bad nil current index")
self.assertEqual(dg.current_future, None, "bad nil current future")
def finish_coroutines(self, iteration, futures):
if iteration == 3:
futures[2].set_result(24)
elif iteration == 5:
futures[0].set_exception(ZeroDivisionError())
elif iteration == 8:
futures[1].set_result(42)
futures[3].set_result(84)
if iteration < 8:
self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures)
@gen_test
def test_iterator(self):
futures = [Future(), Future(), Future(), Future()]
self.finish_coroutines(0, futures)
g = gen.WaitIterator(*futures)
i = 0
while not g.done():
try:
r = yield g.next()
except ZeroDivisionError:
self.assertIs(g.current_future, futures[0],
'exception future invalid')
else:
if i == 0:
self.assertEqual(r, 24, 'iterator value incorrect')
self.assertEqual(g.current_index, 2, 'wrong index')
elif i == 2:
self.assertEqual(r, 42, 'iterator value incorrect')
self.assertEqual(g.current_index, 1, 'wrong index')
elif i == 3:
self.assertEqual(r, 84, 'iterator value incorrect')
self.assertEqual(g.current_index, 3, 'wrong index')
i += 1
@skipBefore35
@gen_test
def test_iterator_async_await(self):
# Recreate the previous test with py35 syntax. It's a little clunky
# because of the way the previous test handles an exception on
# a single iteration.
futures = [Future(), Future(), Future(), Future()]
self.finish_coroutines(0, futures)
self.finished = False
namespace = exec_test(globals(), locals(), """
async def f():
i = 0
g = gen.WaitIterator(*futures)
try:
async for r in g:
if i == 0:
self.assertEqual(r, 24, 'iterator value incorrect')
self.assertEqual(g.current_index, 2, 'wrong index')
else:
raise Exception("expected exception on iteration 1")
i += 1
except ZeroDivisionError:
i += 1
async for r in g:
if i == 2:
self.assertEqual(r, 42, 'iterator value incorrect')
self.assertEqual(g.current_index, 1, 'wrong index')
elif i == 3:
self.assertEqual(r, 84, 'iterator value incorrect')
self.assertEqual(g.current_index, 3, 'wrong index')
else:
raise Exception("didn't expect iteration %d" % i)
i += 1
self.finished = True
""")
yield namespace['f']()
self.assertTrue(self.finished)
@gen_test
def test_no_ref(self):
# In this usage, there is no direct hard reference to the
# WaitIterator itself, only the Future it returns. Since
# WaitIterator uses weak references internally to improve GC
# performance, this used to cause problems.
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
gen.WaitIterator(gen.sleep(0)).next())
if __name__ == '__main__':
unittest.main()

View file

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2012-06-14 01:10-0700\n"
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@ -16,7 +16,32 @@ msgstr ""
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
#: extract_me.py:1
#: extract_me.py:11
msgid "school"
msgstr "école"
#: extract_me.py:12
msgctxt "law"
msgid "right"
msgstr "le droit"
#: extract_me.py:13
msgctxt "good"
msgid "right"
msgstr "le bien"
#: extract_me.py:14
msgctxt "organization"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le club"
msgstr[1] "les clubs"
#: extract_me.py:15
msgctxt "stick"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le bâton"
msgstr[1] "les bâtons"

View file

@ -5,11 +5,15 @@ from __future__ import absolute_import, division, print_function, with_statement
import base64
import binascii
from contextlib import closing
import copy
import functools
import sys
import threading
import datetime
from io import BytesIO
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -19,13 +23,9 @@ from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.util import u
from tornado.web import Application, RequestHandler, url
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO
from tornado.httputil import format_timestamp, HTTPHeaders
class HelloWorldHandler(RequestHandler):
@ -41,10 +41,26 @@ class PostHandler(RequestHandler):
self.get_argument("arg1"), self.get_argument("arg2")))
class PutHandler(RequestHandler):
def put(self):
self.write("Put body: ")
self.write(self.request.body)
class RedirectHandler(RequestHandler):
def prepare(self):
self.write('redirects can have bodies too')
self.redirect(self.get_argument("url"),
status=int(self.get_argument("status", "302")))
class ChunkHandler(RequestHandler):
@gen.coroutine
def get(self):
self.write("asdf")
self.flush()
# Wait a bit to ensure the chunks are sent and received separately.
yield gen.sleep(0.01)
self.write("qwer")
@ -83,6 +99,13 @@ class ContentLength304Handler(RequestHandler):
pass
class PatchHandler(RequestHandler):
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
@ -101,6 +124,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
return Application([
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/put", PutHandler),
url("/redirect", RedirectHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
@ -108,8 +133,15 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
url('/patch', PatchHandler),
], gzip=True)
def test_patch_receives_payload(self):
body = b"some patch data"
response = self.fetch("/patch", method='PATCH', body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
@ -152,6 +184,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
@ -263,7 +297,7 @@ Transfer-Encoding: chunked
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type)
self.assertEqual(type(response.body), bytes)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
@ -274,23 +308,26 @@ Transfer-Encoding: chunked
chunks = []
def header_callback(header_line):
if header_line.startswith('HTTP/'):
if header_line.startswith('HTTP/1.1 101'):
# Upgrading to HTTP/2
pass
elif header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
headers[k] = v.strip()
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
self.assertEqual(len(first_line), 1)
self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
self.assertEqual(len(first_line), 1, first_line)
self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
@ -301,7 +338,7 @@ Transfer-Encoding: chunked
return True
def header_callback(header_line):
if header_line.startswith('Content-Type:'):
if header_line.lower().startswith('content-type:'):
1 / 0
with ExceptionStackContext(error_handler):
@ -314,10 +351,53 @@ Transfer-Encoding: chunked
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults)
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
self.assertEqual(response.body, b'TestDefaultUserAgent')
client.close()
try:
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
self.assertEqual(response.body, b'TestDefaultUserAgent')
finally:
client.close()
def test_header_types(self):
# Header values may be passed as character or utf8 byte strings,
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u("MyUserAgent"), b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers['User-Agent'] = value
resp = self.fetch('/user_agent', headers=headers)
self.assertEqual(
resp.body, b"MyUserAgent",
"response=%r, value=%r, container=%r" %
(resp.body, value, container))
def test_multi_line_headers(self):
# Multi-line http headers are rare but rfc-allowed
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
""".replace(b"\n", b"\r\n"), callback=stream.close)
def accept_callback(conn, address):
stream = IOStream(conn, io_loop=self.io_loop)
stream.read_until(b"\r\n\r\n",
functools.partial(write_response, stream))
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
resp = self.wait()
resp.rethrow()
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
@ -361,6 +441,11 @@ Transfer-Encoding: chunked
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
self.assertEqual(response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
@ -387,18 +472,54 @@ Transfer-Encoding: chunked
allow_nonstandard_methods=True)
self.assertEqual(response.body, b'OTHER')
@gen_test
def test_body(self):
hello_url = self.get_url('/hello')
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, body='data')
def test_body_sanity_checks(self):
# These methods require a body.
for method in ('POST', 'PUT', 'PATCH'):
with self.assertRaises(ValueError) as context:
resp = self.fetch('/all_methods', method=method)
resp.rethrow()
self.assertIn('must not be None', str(context.exception))
self.assertTrue('must be empty' in str(context.exception))
resp = self.fetch('/all_methods', method=method,
allow_nonstandard_methods=True)
self.assertEqual(resp.code, 200)
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, method='POST')
# These methods don't allow a body.
for method in ('GET', 'DELETE', 'OPTIONS'):
with self.assertRaises(ValueError) as context:
resp = self.fetch('/all_methods', method=method, body=b'asdf')
resp.rethrow()
self.assertIn('must be None', str(context.exception))
self.assertTrue('must not be empty' in str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
if method != 'GET':
resp = self.fetch('/all_methods', method=method, body=b'asdf',
allow_nonstandard_methods=True)
resp.rethrow()
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
# def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
response = self.fetch("/redirect?status=307&url=/put",
method="PUT", body=b"hello")
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
class RequestProxyTest(unittest.TestCase):
@ -471,14 +592,19 @@ class SyncHTTPClientTest(unittest.TestCase):
def tearDown(self):
def stop_server():
self.server.stop()
self.server_ioloop.stop()
# Delay the shutdown of the IOLoop by one iteration because
# the server may still have some cleanup work left when
# the client finishes with the response (this is noticable
# with http/2, which leaves a Future with an unexamined
# StreamClosedError on the loop).
self.server_ioloop.add_callback(self.server_ioloop.stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return 'http://localhost:%d%s' % (self.port, path)
return 'http://127.0.0.1:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))
@ -515,3 +641,21 @@ class HTTPRequestTestCase(unittest.TestCase):
request = HTTPRequest('http://example.com')
request.body = 'foo'
self.assertEqual(request.body, utf8('foo'))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest('http://example.com', if_modified_since=http_date)
self.assertEqual(request.headers,
{'If-Modified-Since': format_timestamp(http_date)})
class HTTPErrorTestCase(unittest.TestCase):
def test_copy(self):
e = HTTPError(403)
e2 = copy.copy(e)
self.assertIsNot(e, e2)
self.assertEqual(e.code, e2.code)
def test_str(self):
e = HTTPError(403)
self.assertEqual(str(e), "HTTP 403: Forbidden")

View file

@ -9,12 +9,12 @@ from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream
from tornado.log import gen_log, app_log
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.util import u
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing
import datetime
@ -25,17 +25,14 @@ import socket
import ssl
import sys
import tempfile
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
from io import BytesIO
def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
chunks = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
@ -120,6 +117,16 @@ class SSLTestMixin(object):
response = self.wait()
self.assertEqual(response.code, 599)
def test_error_logging(self):
# No stack traces are logged for SSL errors.
with ExpectLog(gen_log, 'SSL Error') as expect_log:
self.http_client.fetch(
self.get_url("/").replace("https:", "http:"),
self.stop)
response = self.wait()
self.assertEqual(response.code, 599)
self.assertFalse(expect_log.logged_stack)
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
@ -165,19 +172,22 @@ class BadSSLOptionsTest(unittest.TestCase):
application = Application()
module_dir = os.path.dirname(__file__)
existing_certificate = os.path.join(module_dir, 'test.crt')
existing_key = os.path.join(module_dir, 'test.key')
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
"certfile": "/__mising__.crt",
})
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
"certfile": existing_certificate,
"keyfile": "/__missing__.key"
})
self.assertRaises((ValueError, IOError),
HTTPServer, application, ssl_options={
"certfile": "/__mising__.crt",
})
self.assertRaises((ValueError, IOError),
HTTPServer, application, ssl_options={
"certfile": existing_certificate,
"keyfile": "/__missing__.key"
})
# This actually works because both files exist
HTTPServer(application, ssl_options={
"certfile": existing_certificate,
"keyfile": existing_certificate
"keyfile": existing_key,
})
@ -199,14 +209,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
def get_app(self):
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
stream.write(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
newline.join(headers +
[utf8("Content-Length: %d" % len(body))]) +
newline + newline + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
return body
@ -236,12 +246,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
self.assertEqual(u("\u00f3"), data["filename"])
self.assertEqual(u("\u00fa"), data["filebody"])
def test_newlines(self):
# We support both CRLF and bare LF as line separators.
for newline in (b"\r\n", b"\n"):
response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
newline=newline)
self.assertEqual(response, b'Hello world')
def test_100_continue(self):
# Run through a 100-continue interaction by hand:
# When given Expect: 100-continue, we get a 100 response after the
# headers, and then the real response after the body.
stream = IOStream(socket.socket(), io_loop=self.io_loop)
stream.connect(("localhost", self.get_http_port()), callback=self.stop)
stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
@ -297,10 +314,10 @@ class TypeCheckHandler(RequestHandler):
# secure cookies
self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)
def post(self):
self.check_type('body', self.request.body, bytes_type)
self.check_type('body', self.request.body, bytes)
self.write(self.errors)
def get(self):
@ -358,7 +375,7 @@ class HTTPServerTest(AsyncHTTPTestCase):
# if the data is not utf8. On python 2 parse_qs will work,
# but then the recursive_unicode call in EchoHandler will
# fail.
if str is bytes_type:
if str is bytes:
return
with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
response = self.fetch(
@ -378,7 +395,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def tearDown(self):
@ -559,7 +576,7 @@ class UnixSocketTest(AsyncTestCase):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
self.stream.read_until(b"\r\n", self.stop)
response = self.wait()
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
self.stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
@ -587,6 +604,9 @@ class KeepAliveTest(AsyncHTTPTestCase):
def get(self):
self.finish('Hello world')
def post(self):
self.finish('Hello world')
class LargeHandler(RequestHandler):
def get(self):
# 512KB should be bigger than the socket buffers so it will
@ -625,13 +645,13 @@ class KeepAliveTest(AsyncHTTPTestCase):
# The next few methods are a crude manual http client
def connect(self):
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def read_headers(self):
self.stream.read_until(b'\r\n', self.stop)
first_line = self.wait()
self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
self.stream.read_until(b'\r\n\r\n', self.stop)
header_bytes = self.wait()
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
@ -687,6 +707,17 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_http10_keepalive_extra_crlf(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_pipelined_requests(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
@ -715,6 +746,19 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.read_headers()
self.close()
def test_keepalive_chunked(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'POST / HTTP/1.0\r\nConnection: keep-alive\r\n'
b'Transfer-Encoding: chunked\r\n'
b'\r\n0\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
class GzipBaseTest(object):
def get_app(self):
@ -786,8 +830,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def get_app(self):
class App(HTTPServerConnectionDelegate):
def start_request(self, connection):
return StreamingChunkSizeTest.MessageDelegate(connection)
def start_request(self, server_conn, request_conn):
return StreamingChunkSizeTest.MessageDelegate(request_conn)
return App()
def fetch_chunk_sizes(self, **kwargs):
@ -834,6 +878,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def test_chunked_compressed(self):
compressed = self.compress(self.BODY)
self.assertGreater(len(compressed), 20)
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
@ -854,9 +899,12 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
self.assertEqual(response.body, b"Hello world")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
with ExpectLog(gen_log, "Unsatisfiable read", required=False):
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
self.assertEqual(response.code, 599)
# 431 is "Request Header Fields Too Large", defined in RFC
# 6585. However, many implementations just close the
# connection in this case, resulting in a 599.
self.assertIn(response.code, (431, 599))
@skipOnTravis
@ -878,7 +926,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
def connect(self):
stream = IOStream(socket.socket())
stream.connect(('localhost', self.get_http_port()), self.stop)
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.streams.append(stream)
return stream
@ -1023,6 +1071,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
# delegate interface, and writes its response via request.write
# instead of request.connection.write_headers.
def handle_request(request):
self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
# This test will be skipped if we're using HTTP/2,
# so just close it out cleanly using the modern interface.
request.connection.write_headers(
ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
@ -1032,4 +1089,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
def test_legacy_interface(self):
response = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(response.body, b"Hello world")

View file

@ -2,19 +2,21 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp
from tornado.escape import utf8
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
from tornado.escape import utf8, native_str
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
from tornado.util import u
import copy
import datetime
import logging
import pickle
import time
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
url = url_concat(
"https://localhost/path",
@ -228,6 +230,95 @@ Foo: even
("Foo", "bar baz"),
("Foo", "even more lines")])
def test_unicode_newlines(self):
# Ensure that only \r\n is recognized as a header separator, and not
# the other newline-like unicode characters.
# Characters that are likely to be problematic can be found in
# http://unicode.org/standard/reports/tr13/tr13-5.html
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
u('\u001b'), # VERTICAL TAB
u('\u001c'), # FILE SEPARATOR
u('\u001d'), # GROUP SEPARATOR
u('\u001e'), # RECORD SEPARATOR
u('\u0085'), # NEXT LINE
u('\u2028'), # LINE SEPARATOR
u('\u2029'), # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
for encoding in ['utf8', 'latin1']:
try:
try:
encoded = newline.encode(encoding)
except UnicodeEncodeError:
# Some chars cannot be represented in latin1
continue
data = b'Cookie: foo=' + encoded + b'bar'
# parse() wants a native_str, so decode through latin1
# in the same way the real parser does.
headers = HTTPHeaders.parse(
native_str(data.decode('latin1')))
expected = [('Cookie', 'foo=' +
native_str(encoded.decode('latin1')) + 'bar')]
self.assertEqual(
expected, list(headers.get_all()))
except Exception:
gen_log.warning("failed while trying %r in %s",
newline, encoding)
raise
def test_optional_cr(self):
# Both CRLF and LF should be accepted as separators. CR should not be
# part of the data when followed by LF, but it is a normal char
# otherwise (or should bare CR be an error?)
headers = HTTPHeaders.parse(
'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
self.assertEqual(sorted(headers.get_all()),
[('Cr', 'cr\rMore: more'),
('Crlf', 'crlf'),
('Lf', 'lf'),
])
def test_copy(self):
all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
h1 = HTTPHeaders()
for k, v in all_pairs:
h1.add(k, v)
h2 = h1.copy()
h3 = copy.copy(h1)
h4 = copy.deepcopy(h1)
for headers in [h1, h2, h3, h4]:
# All the copies are identical, no matter how they were
# constructed.
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
for headers in [h2, h3, h4]:
# Neither the dict or its member lists are reused.
self.assertIsNot(headers, h1)
self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
def test_pickle_roundtrip(self):
headers = HTTPHeaders()
headers.add('Set-Cookie', 'a=b')
headers.add('Set-Cookie', 'c=d')
headers.add('Content-Type', 'text/html')
pickled = pickle.dumps(headers)
unpickled = pickle.loads(pickled)
self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
def test_setdefault(self):
headers = HTTPHeaders()
headers['foo'] = 'bar'
# If a value is present, setdefault returns it without changes.
self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
self.assertEqual(headers['foo'], 'bar')
# If a value is not present, setdefault sets it for future use.
self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
self.assertEqual(headers['quux'], 'xyzzy')
self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
@ -253,3 +344,30 @@ class FormatTimestampTest(unittest.TestCase):
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
def test_default_constructor(self):
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
HTTPServerRequest(uri='/')
def test_body_is_a_byte_string(self):
requets = HTTPServerRequest(uri='/')
self.assertIsInstance(requets.body, bytes)
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
PATH = "/foo"
VERSION = "HTTP/1.1"
def test_parse_request_start_line(self):
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
parsed_start_line = parse_request_start_line(start_line)
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)

View file

@ -1,3 +1,4 @@
# flake8: noqa
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.util import unittest

View file

@ -11,11 +11,12 @@ import threading
import time
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
from tornado.log import app_log
from tornado.platform.select import _Select
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis, skipBefore35, exec_test
try:
from concurrent import futures
@ -23,6 +24,42 @@ except ImportError:
futures = None
class FakeTimeSelect(_Select):
def __init__(self):
self._time = 1000
super(FakeTimeSelect, self).__init__()
def time(self):
return self._time
def sleep(self, t):
self._time += t
def poll(self, timeout):
events = super(FakeTimeSelect, self).poll(0)
if events:
return events
self._time += timeout
return []
class FakeTimeIOLoop(PollIOLoop):
"""IOLoop implementation with a fake and deterministic clock.
The clock advances as needed to trigger timeouts immediately.
For use when testing code that involves the passage of time
and no external dependencies.
"""
def initialize(self):
self.fts = FakeTimeSelect()
super(FakeTimeIOLoop, self).initialize(impl=self.fts,
time_func=self.fts.time)
def sleep(self, t):
"""Simulate a blocking sleep by advancing the clock."""
self.fts.sleep(t)
class TestIOLoop(AsyncTestCase):
@skipOnTravis
def test_add_callback_wakeup(self):
@ -173,6 +210,27 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
@ -185,6 +243,23 @@ class TestIOLoop(AsyncTestCase):
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
@ -216,6 +291,7 @@ class TestIOLoop(AsyncTestCase):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
@ -238,6 +314,7 @@ class TestIOLoop(AsyncTestCase):
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
@ -252,6 +329,7 @@ class TestIOLoop(AsyncTestCase):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
@ -269,7 +347,7 @@ class TestIOLoop(AsyncTestCase):
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@ -280,41 +358,109 @@ class TestIOLoop(AsyncTestCase):
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1/0
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipBefore35
def test_exception_logging_native_coro(self):
"""The IOLoop examines exceptions from awaitables and logs them."""
namespace = exec_test(globals(), locals(), """
async def callback():
self.io_loop.add_callback(self.stop)
1 / 0
""")
with NullContext():
self.io_loop.add_callback(namespace["callback"])
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1/0)
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b'abc')
server.send(b'abc')
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.03, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b'abc'])
finally:
client.close()
server.close()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
self.io_loop = None
IOLoop.clear_current()
def tearDown(self):
self.io_loop.close()
if self.io_loop is not None:
self.io_loop.close()
def test_current(self):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
def test_default_current(self):
self.io_loop = IOLoop()
# The first IOLoop with default arguments is made current.
self.assertIs(self.io_loop, IOLoop.current())
# A second IOLoop can be created but is not made current.
io_loop2 = IOLoop()
self.assertIs(self.io_loop, IOLoop.current())
io_loop2.close()
def test_non_current(self):
self.io_loop = IOLoop(make_current=False)
# The new IOLoop is not initially made current.
self.assertIsNone(IOLoop.current(instance=False))
# Starting the IOLoop makes it current, and stopping the loop
# makes it non-current. This process is repeatable.
for i in range(3):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
# Now that the loop is stopped, it is no longer current.
self.assertIsNone(IOLoop.current(instance=False))
def test_force_current(self):
self.io_loop = IOLoop(make_current=True)
self.assertIs(self.io_loop, IOLoop.current())
with self.assertRaises(RuntimeError):
# A second make_current=True construction cannot succeed.
IOLoop(make_current=True)
# current() was not affected by the failed construction.
self.assertIs(self.io_loop, IOLoop.current())
class TestIOLoopAddCallback(AsyncTestCase):
@ -424,7 +570,8 @@ class TestIOLoopRunSync(unittest.TestCase):
self.io_loop.close()
def test_sync_result(self):
self.assertEqual(self.io_loop.run_sync(lambda: 42), 42)
with self.assertRaises(gen.BadYieldError):
self.io_loop.run_sync(lambda: 42)
def test_sync_exception(self):
with self.assertRaises(ZeroDivisionError):
@ -456,6 +603,56 @@ class TestIOLoopRunSync(unittest.TestCase):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
@skipBefore35
def test_native_coroutine(self):
namespace = exec_test(globals(), locals(), """
async def f():
await gen.Task(self.io_loop.add_callback)
""")
self.io_loop.run_sync(namespace['f'])
class TestPeriodicCallback(unittest.TestCase):
def setUp(self):
self.io_loop = FakeTimeIOLoop()
self.io_loop.make_current()
def tearDown(self):
self.io_loop.close()
def test_basic(self):
calls = []
def cb():
calls.append(self.io_loop.time())
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.call_later(50, self.io_loop.stop)
self.io_loop.start()
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
def test_overrun(self):
sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
expected = [
1010, 1020, 1030, # first 3 calls on schedule
1050, 1070, # next 2 delayed one cycle
1100, 1130, # next 2 delayed 2 cycles
1170, 1210, # next 2 delayed 3 cycles
1220, 1230, # then back on schedule.
]
calls = []
def cb():
calls.append(self.io_loop.time())
if not sleep_durations:
self.io_loop.stop()
return
self.io_loop.sleep(sleep_durations.pop(0))
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.start()
self.assertEqual(calls, expected)
if __name__ == "__main__":
unittest.main()

View file

@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.test.util import unittest, skipIfNonUnix, refusing_port
from tornado.web import RequestHandler, Application
import certifi
import errno
import logging
import os
@ -19,6 +19,14 @@ import socket
import ssl
import sys
try:
from unittest import mock # python 3.3
except ImportError:
try:
import mock # third-party mock package
except ImportError:
mock = None
def _server_ssl_options():
return dict(
@ -51,18 +59,18 @@ class TestIOStreamWebMixin(object):
def test_read_until_close(self):
stream = self._make_client_iostream()
stream.connect(('localhost', self.get_http_port()), callback=self.stop)
stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"GET / HTTP/1.0\r\n\r\n")
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.startswith(b"HTTP/1.0 200"))
self.assertTrue(data.startswith(b"HTTP/1.1 200"))
self.assertTrue(data.endswith(b"Hello"))
def test_read_zero_bytes(self):
self.stream = self._make_client_iostream()
self.stream.connect(("localhost", self.get_http_port()),
self.stream.connect(("127.0.0.1", self.get_http_port()),
callback=self.stop)
self.wait()
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
@ -70,7 +78,7 @@ class TestIOStreamWebMixin(object):
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
self.assertEqual(data, b"HTTP/1.0 ")
self.assertEqual(data, b"HTTP/1.1 ")
# zero bytes
self.stream.read_bytes(0, self.stop)
@ -91,7 +99,7 @@ class TestIOStreamWebMixin(object):
def connected_callback():
connected[0] = True
self.stop()
stream.connect(("localhost", self.get_http_port()),
stream.connect(("127.0.0.1", self.get_http_port()),
callback=connected_callback)
# unlike the previous tests, try to write before the connection
# is complete.
@ -121,11 +129,11 @@ class TestIOStreamWebMixin(object):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
("localhost", self.get_http_port()))
("127.0.0.1", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
@ -137,7 +145,7 @@ class TestIOStreamWebMixin(object):
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
@ -147,7 +155,7 @@ class TestIOStreamWebMixin(object):
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
@ -217,17 +225,18 @@ class TestIOStreamMixin(object):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
# epoll IOLoop in this respect)
server_socket, port = bind_unused_port()
server_socket.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
stream = IOStream(socket.socket(), self.io_loop)
self.connect_called = False
def connect_callback():
self.connect_called = True
self.stop()
stream.set_close_callback(self.stop)
# log messages vary by platform and ioloop implementation
with ExpectLog(gen_log, ".*", required=False):
stream.connect(("localhost", port), connect_callback)
stream.connect(("127.0.0.1", port), connect_callback)
self.wait()
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
@ -238,18 +247,22 @@ class TestIOStreamMixin(object):
# cygwin's errnos don't match those used on native windows python
self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
@unittest.skipIf(mock is None, 'mock package not present')
def test_gaierror(self):
# Test that IOStream sets its exc_info on getaddrinfo error
# Test that IOStream sets its exc_info on getaddrinfo error.
# It's difficult to reliably trigger a getaddrinfo error;
# some resolvers own't even return errors for malformed names,
# so we mock it instead. If IOStream changes to call a Resolver
# before sock.connect, the mock target will need to change too.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = IOStream(s, io_loop=self.io_loop)
stream.set_close_callback(self.stop)
# To reliably generate a gaierror we use a malformed domain name
# instead of a name that's simply unlikely to exist (since
# opendns and some ISPs return bogus addresses for nonexistent
# domains instead of the proper error codes).
with ExpectLog(gen_log, "Connect error"):
stream.connect(('an invalid domain', 54321))
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
with mock.patch('socket.socket.connect',
side_effect=socket.gaierror('boom')):
with ExpectLog(gen_log, "Connect error"):
stream.connect(('localhost', 80), callback=self.stop)
self.wait()
self.assertIsInstance(stream.error, socket.gaierror)
def test_read_callback_error(self):
# Test that IOStream sets its exc_info when a read callback throws
@ -308,6 +321,7 @@ class TestIOStreamMixin(object):
def streaming_callback(data):
chunks.append(data)
self.stop()
def close_callback(data):
assert not data, data
closed[0] = True
@ -325,6 +339,31 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_streaming_until_close_future(self):
server, client = self.make_iostream_pair()
try:
chunks = []
@gen.coroutine
def client_task():
yield client.read_until_close(streaming_callback=chunks.append)
@gen.coroutine
def server_task():
yield server.write(b"1234")
yield gen.sleep(0.01)
yield server.write(b"5678")
server.close()
@gen.coroutine
def f():
yield [client_task(), server_task()]
self.io_loop.run_sync(f)
self.assertEqual(chunks, [b"1234", b"5678"])
finally:
server.close()
client.close()
def test_delayed_close_callback(self):
# The scenario: Server closes the connection while there is a pending
# read that can be served out of buffered data. The client does not
@ -353,6 +392,7 @@ class TestIOStreamMixin(object):
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
@ -417,6 +457,18 @@ class TestIOStreamMixin(object):
server.close()
client.close()
@unittest.skipIf(mock is None, 'mock package not present')
def test_read_until_close_with_error(self):
server, client = self.make_iostream_pair()
try:
with mock.patch('tornado.iostream.BaseIOStream._try_inline_read',
side_effect=IOError('boom')):
with self.assertRaisesRegexp(IOError, 'boom'):
client.read_until_close(self.stop)
finally:
server.close()
client.close()
def test_streaming_read_until_close_after_close(self):
# Same as the preceding test but with a streaming_callback.
# All data should go through the streaming callback,
@ -511,7 +563,7 @@ class TestIOStreamMixin(object):
server, client = self.make_iostream_pair()
server.set_close_callback(self.stop)
try:
# Start a read that will be fullfilled asynchronously.
# Start a read that will be fulfilled asynchronously.
server.read_bytes(1, lambda data: None)
client.write(b'a')
# Stub out read_from_fd to make it fail.
@ -532,6 +584,7 @@ class TestIOStreamMixin(object):
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
def close_callback():
closed[0] = True
self.stop()
@ -724,6 +777,26 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_flow_control(self):
MB = 1024 * 1024
server, client = self.make_iostream_pair(max_buffer_size=5 * MB)
try:
# Client writes more than the server will accept.
client.write(b"a" * 10 * MB)
# The server pauses while reading.
server.read_bytes(MB, self.stop)
self.wait()
self.io_loop.call_later(0.1, self.stop)
self.wait()
# The client's writes have been blocked; the server can
# continue to read gradually.
for i in range(9):
server.read_bytes(MB, self.stop)
self.wait()
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
@ -732,7 +805,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
return SSLIOStream(socket.socket(), io_loop=self.io_loop)
return SSLIOStream(socket.socket(), io_loop=self.io_loop,
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
@ -752,7 +826,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
return SSLIOStream(connection, io_loop=self.io_loop,
ssl_options=dict(cert_reqs=ssl.CERT_NONE),
**kwargs)
# This will run some tests that are basically redundant but it's the
@ -820,10 +896,10 @@ class TestIOStreamStartTLS(AsyncTestCase):
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
def client_start_tls(self, ssl_options=None):
def client_start_tls(self, ssl_options=None, server_hostname=None):
client_stream = self.client_stream
self.client_stream = None
return client_stream.start_tls(False, ssl_options)
return client_stream.start_tls(False, ssl_options, server_hostname)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
@ -842,7 +918,7 @@ class TestIOStreamStartTLS(AsyncTestCase):
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
client_future = self.client_start_tls()
client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
@ -853,12 +929,125 @@ class TestIOStreamStartTLS(AsyncTestCase):
@gen_test
def test_handshake_fail(self):
self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
server_future = self.server_start_tls(_server_ssl_options())
# Certificates are verified with the default configuration.
client_future = self.client_start_tls(server_hostname="localhost")
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
with self.assertRaises((ssl.SSLError, socket.error)):
yield server_future
@unittest.skipIf(not hasattr(ssl, 'create_default_context'),
'ssl.create_default_context not present')
@gen_test
def test_check_hostname(self):
# Test that server_hostname parameter to start_tls is being used.
# The check_hostname functionality is only available in python 2.7 and
# up and in python 3.4 and up.
server_future = self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
ssl.create_default_context(),
server_hostname=b'127.0.0.1')
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
# The client fails to connect with an SSL error.
yield client_future
with self.assertRaises(Exception):
# The server fails to connect, but the exact error is unspecified.
yield server_future
class WaitForHandshakeTest(AsyncTestCase):
@gen.coroutine
def connect_to_server(self, server_cls):
server = client = None
try:
sock, port = bind_unused_port()
server = server_cls(ssl_options=_server_ssl_options())
server.add_socket(sock)
client = SSLIOStream(socket.socket(),
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
yield client.connect(('127.0.0.1', port))
self.assertIsNotNone(client.socket.cipher())
finally:
if server is not None:
server.stop()
if client is not None:
client.close()
@gen_test
def test_wait_for_handshake_callback(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
# The handshake has not yet completed.
test.assertIsNone(stream.socket.cipher())
self.stream = stream
stream.wait_for_handshake(self.handshake_done)
def handshake_done(self):
# Now the handshake is done and ssl information is available.
test.assertIsNotNone(self.stream.socket.cipher())
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_future(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
test.assertIsNone(stream.socket.cipher())
test.io_loop.spawn_callback(self.handle_connection, stream)
@gen.coroutine
def handle_connection(self, stream):
yield stream.wait_for_handshake()
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_already_waiting_error(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
stream.wait_for_handshake(self.handshake_done)
test.assertRaises(RuntimeError, stream.wait_for_handshake)
def handshake_done(self):
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_already_connected(self):
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
self.stream = stream
stream.wait_for_handshake(self.handshake_done)
def handshake_done(self):
self.stream.wait_for_handshake(self.handshake2_done)
def handshake2_done(self):
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@skipIfNonUnix

View file

@ -2,9 +2,12 @@ from __future__ import absolute_import, division, print_function, with_statement
import datetime
import os
import shutil
import tempfile
import tornado.locale
from tornado.escape import utf8
from tornado.test.util import unittest
from tornado.escape import utf8, to_unicode
from tornado.test.util import unittest, skipOnAppEngine
from tornado.util import u, unicode_type
@ -34,6 +37,28 @@ class TranslationLoaderTest(unittest.TestCase):
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
# tempfile.mkdtemp is not available on app engine.
@skipOnAppEngine
def test_csv_bom(self):
with open(os.path.join(os.path.dirname(__file__), 'csv_translations',
'fr_FR.csv'), 'rb') as f:
char_data = to_unicode(f.read())
# Re-encode our input data (which is utf-8 without BOM) in
# encodings that use the BOM and ensure that we can still load
# it. Note that utf-16-le and utf-16-be do not write a BOM,
# so we only test whichver variant is native to our platform.
for encoding in ['utf-8-sig', 'utf-16']:
tmpdir = tempfile.mkdtemp()
try:
with open(os.path.join(tmpdir, 'fr_FR.csv'), 'wb') as f:
f.write(char_data.encode(encoding))
tornado.locale.load_translations(tmpdir)
locale = tornado.locale.get('fr_FR')
self.assertIsInstance(locale, tornado.locale.CSVLocale)
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
finally:
shutil.rmtree(tmpdir)
def test_gettext(self):
tornado.locale.load_gettext_translations(
os.path.join(os.path.dirname(__file__), 'gettext_translations'),
@ -41,6 +66,12 @@ class TranslationLoaderTest(unittest.TestCase):
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
class LocaleDataTest(unittest.TestCase):
@ -57,3 +88,43 @@ class EnglishTest(unittest.TestCase):
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(hours=2), full_format=False),
'2 hours ago')
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1), full_format=False, shorter=True),
'yesterday')
date = now - datetime.timedelta(days=2)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
locale._weekdays[date.weekday()])
date = now - datetime.timedelta(days=300)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d' % (locale._months[date.month - 1], date.day))
date = now - datetime.timedelta(days=500)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
def test_friendly_number(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.friendly_number(1000000), '1,000,000')
def test_list(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.list([]), '')
self.assertEqual(locale.list(['A']), 'A')
self.assertEqual(locale.list(['A', 'B']), 'A and B')
self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
def test_format_day(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')

View file

@ -0,0 +1,518 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from datetime import timedelta
from tornado import gen, locks
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest, skipBefore35, exec_test
class ConditionTest(AsyncTestCase):
def setUp(self):
super(ConditionTest, self).setUp()
self.history = []
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
def callback(_):
if not future.result():
# wait() resolved to False, meaning it timed out.
self.history.append('timeout')
else:
self.history.append(key)
future.add_done_callback(callback)
def test_repr(self):
c = locks.Condition()
self.assertIn('Condition', repr(c))
self.assertNotIn('waiters', repr(c))
c.wait()
self.assertIn('waiters', repr(c))
@gen_test
def test_notify(self):
c = locks.Condition()
self.io_loop.call_later(0.01, c.notify)
yield c.wait()
def test_notify_1(self):
c = locks.Condition()
self.record_done(c.wait(), 'wait1')
self.record_done(c.wait(), 'wait2')
c.notify(1)
self.history.append('notify1')
c.notify(1)
self.history.append('notify2')
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
self.history)
def test_notify_n(self):
c = locks.Condition()
for i in range(6):
self.record_done(c.wait(), i)
c.notify(3)
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(3)), self.history)
c.notify(1)
self.assertEqual(list(range(4)), self.history)
c.notify(2)
self.assertEqual(list(range(6)), self.history)
def test_notify_all(self):
c = locks.Condition()
for i in range(4):
self.record_done(c.wait(), i)
c.notify_all()
self.history.append('notify_all')
# Callbacks execute in the order they were registered.
self.assertEqual(
list(range(4)) + ['notify_all'],
self.history)
@gen_test
def test_wait_timeout(self):
c = locks.Condition()
wait = c.wait(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, c.notify) # Too late.
yield gen.sleep(0.03)
self.assertFalse((yield wait))
@gen_test
def test_wait_timeout_preempted(self):
c = locks.Condition()
# This fires before the wait times out.
self.io_loop.call_later(0.01, c.notify)
wait = c.wait(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield wait # No TimeoutError.
@gen_test
def test_notify_n_with_timeout(self):
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
# Wait for that timeout to expire, then do notify(2) and make
# sure everyone runs. Verifies that a timed-out callback does
# not count against the 'n' argument to notify().
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
self.record_done(c.wait(), 3)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify(2)
yield gen.sleep(0.01)
self.assertEqual(['timeout', 0, 2], self.history)
self.assertEqual(['timeout', 0, 2], self.history)
c.notify()
self.assertEqual(['timeout', 0, 2, 3], self.history)
@gen_test
def test_notify_all_with_timeout(self):
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify_all()
self.assertEqual(['timeout', 0, 2], self.history)
@gen_test
def test_nested_notify(self):
# Ensure no notifications lost, even if notify() is reentered by a
# waiter calling notify().
c = locks.Condition()
# Three waiters.
futures = [c.wait() for _ in range(3)]
# First and second futures resolved. Second future reenters notify(),
# resolving third future.
futures[1].add_done_callback(lambda _: c.notify())
c.notify(2)
self.assertTrue(all(f.done() for f in futures))
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
c = locks.Condition()
for _ in range(101):
c.wait(timedelta(seconds=0.01))
future = c.wait()
self.assertEqual(102, len(c._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(c._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
c.notify()
self.assertTrue(future.done())
class EventTest(AsyncTestCase):
def test_repr(self):
event = locks.Event()
self.assertTrue('clear' in str(event))
self.assertFalse('set' in str(event))
event.set()
self.assertFalse('clear' in str(event))
self.assertTrue('set' in str(event))
def test_event(self):
e = locks.Event()
future_0 = e.wait()
e.set()
future_1 = e.wait()
e.clear()
future_2 = e.wait()
self.assertTrue(future_0.done())
self.assertTrue(future_1.done())
self.assertFalse(future_2.done())
@gen_test
def test_event_timeout(self):
e = locks.Event()
with self.assertRaises(TimeoutError):
yield e.wait(timedelta(seconds=0.01))
# After a timed-out waiter, normal operation works.
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
yield e.wait(timedelta(seconds=1))
def test_event_set_multiple(self):
e = locks.Event()
e.set()
e.set()
self.assertTrue(e.is_set())
def test_event_wait_clear(self):
e = locks.Event()
f0 = e.wait()
e.clear()
f1 = e.wait()
e.set()
self.assertTrue(f0.done())
self.assertTrue(f1.done())
class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)
def test_repr(self):
sem = locks.Semaphore()
self.assertIn('Semaphore', repr(sem))
self.assertIn('unlocked,value:1', repr(sem))
sem.acquire()
self.assertIn('locked', repr(sem))
self.assertNotIn('waiters', repr(sem))
sem.acquire()
self.assertIn('waiters', repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
f0 = sem.acquire()
self.assertTrue(f0.done())
# Wait for release().
f1 = sem.acquire()
self.assertFalse(f1.done())
f2 = sem.acquire()
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())
sem.release()
# Now acquire() is instant.
self.assertTrue(sem.acquire().done())
self.assertEqual(0, len(sem._waiters))
@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
acquire = sem.acquire(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, sem.release) # Too late.
yield gen.sleep(0.3)
with self.assertRaises(gen.TimeoutError):
yield acquire
sem.acquire()
f = sem.acquire()
self.assertFalse(f.done())
sem.release()
self.assertTrue(f.done())
@gen_test
def test_acquire_timeout_preempted(self):
sem = locks.Semaphore(1)
yield sem.acquire()
# This fires before the wait times out.
self.io_loop.call_later(0.01, sem.release)
acquire = sem.acquire(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield acquire # No TimeoutError.
def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
# Now the counter is 3. We can acquire three times before blocking.
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertFalse(sem.acquire().done())
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
sem = locks.Semaphore(value=0)
futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
future = sem.acquire()
self.assertEqual(102, len(sem._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(sem._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Prevent "Future exception was never retrieved" messages.
for future in futures:
self.assertRaises(TimeoutError, future.result)
class SemaphoreContextManagerTest(AsyncTestCase):
@gen_test
def test_context_manager(self):
sem = locks.Semaphore()
with (yield sem.acquire()) as yielded:
self.assertTrue(yielded is None)
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@skipBefore35
@gen_test
def test_context_manager_async_await(self):
# Repeat the above test using 'async with'.
sem = locks.Semaphore()
namespace = exec_test(globals(), locals(), """
async def f():
async with sem as yielded:
self.assertTrue(yielded is None)
""")
yield namespace['f']()
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_exception(self):
sem = locks.Semaphore()
with self.assertRaises(ZeroDivisionError):
with (yield sem.acquire()):
1 / 0
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout(self):
sem = locks.Semaphore()
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout_error(self):
sem = locks.Semaphore(value=0)
with self.assertRaises(gen.TimeoutError):
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Counter is still 0.
self.assertFalse(sem.acquire().done())
@gen_test
def test_context_manager_contended(self):
sem = locks.Semaphore()
history = []
@gen.coroutine
def f(index):
with (yield sem.acquire()):
history.append('acquired %d' % index)
yield gen.sleep(0.01)
history.append('release %d' % index)
yield [f(i) for i in range(2)]
expected_history = []
for i in range(2):
expected_history.extend(['acquired %d' % i, 'release %d' % i])
self.assertEqual(expected_history, history)
@gen_test
def test_yield_sem(self):
# Ensure we catch a "with (yield sem)", which should be
# "with (yield sem.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Semaphore()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with sem", which should be
# "with (yield sem.acquire())".
with self.assertRaises(RuntimeError):
with locks.Semaphore():
pass
class BoundedSemaphoreTest(AsyncTestCase):
def test_release_unacquired(self):
sem = locks.BoundedSemaphore()
self.assertRaises(ValueError, sem.release)
# Value is 0.
sem.acquire()
# Block on acquire().
future = sem.acquire()
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Value is 1.
sem.release()
self.assertRaises(ValueError, sem.release)
class LockTests(AsyncTestCase):
def test_repr(self):
lock = locks.Lock()
# No errors.
repr(lock)
lock.acquire()
repr(lock)
def test_acquire_release(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
future = lock.acquire()
self.assertFalse(future.done())
lock.release()
self.assertTrue(future.done())
@gen_test
def test_acquire_fifo(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
N = 5
history = []
@gen.coroutine
def f(idx):
with (yield lock.acquire()):
history.append(idx)
futures = [f(i) for i in range(N)]
self.assertFalse(any(future.done() for future in futures))
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@skipBefore35
@gen_test
def test_acquire_fifo_async_with(self):
# Repeat the above test using `async with lock:`
# instead of `with (yield lock.acquire()):`.
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
N = 5
history = []
namespace = exec_test(globals(), locals(), """
async def f(idx):
async with lock:
history.append(idx)
""")
futures = [namespace['f'](i) for i in range(N)]
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@gen_test
def test_acquire_timeout(self):
lock = locks.Lock()
lock.acquire()
with self.assertRaises(gen.TimeoutError):
yield lock.acquire(timeout=timedelta(seconds=0.01))
# Still locked.
self.assertFalse(lock.acquire().done())
def test_multi_release(self):
lock = locks.Lock()
self.assertRaises(RuntimeError, lock.release)
lock.acquire()
lock.release()
self.assertRaises(RuntimeError, lock.release)
@gen_test
def test_yield_lock(self):
# Ensure we catch a "with (yield lock)", which should be
# "with (yield lock.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Lock()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with lock", which should be
# "with (yield lock.acquire())".
with self.assertRaises(RuntimeError):
with locks.Lock():
pass
if __name__ == '__main__':
unittest.main()

View file

@ -29,7 +29,7 @@ from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import u, bytes_type, basestring_type
from tornado.util import u, basestring_type
@contextlib.contextmanager
@ -95,8 +95,9 @@ class LogFormatterTest(unittest.TestCase):
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes_type, basestring_type):
with ignore_bytes_warning():
self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
@ -159,6 +160,39 @@ class EnablePrettyLoggingTest(unittest.TestCase):
os.unlink(filename)
os.rmdir(tmpdir)
def test_log_file_with_timed_rotating(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + '/test_log'
self.options.log_rotate_mode = 'time'
enable_pretty_logging(options=self.options, logger=self.logger)
self.logger.error('hello')
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + '/test_log*')
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(
f.read(),
r'^\[E [^]]*\] hello$')
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
def test_wrong_rotate_mode_value(self):
try:
self.options.log_file_prefix = 'some_path'
self.options.log_rotate_mode = 'wrong_mode'
self.assertRaises(ValueError, enable_pretty_logging,
options=self.options, logger=self.logger)
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""

View file

@ -9,7 +9,7 @@ import time
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import unittest, skipIfNoNetwork
try:
@ -34,15 +34,6 @@ else:
class _ResolverTestMixin(object):
def skipOnCares(self):
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
if self.resolver.__class__.__name__ == 'CaresResolver':
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
@ -55,8 +46,11 @@ class _ResolverTestMixin(object):
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
def test_bad_host(self):
self.skipOnCares()
def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val)
return True # Halt propagation.
@ -69,12 +63,16 @@ class _ResolverTestMixin(object):
@gen_test
def test_future_interface_bad_host(self):
self.skipOnCares()
with self.assertRaises(Exception):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
@ -82,6 +80,21 @@ class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = BlockingResolver(io_loop=self.io_loop)
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(BlockingResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -94,6 +107,18 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
super(ThreadedResolverTest, self).tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(ThreadedResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
@ -121,6 +146,12 @@ class ThreadedResolverImportTest(unittest.TestCase):
self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -129,10 +160,13 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = CaresResolver(io_loop=self.io_loop)
# TwistedResolver produces consistent errors in our test cases so we
# can test the regular and error cases in the same class.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin,
_ResolverErrorTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop)
@ -166,3 +200,14 @@ class TestPortAllocation(unittest.TestCase):
finally:
for sock in sockets:
sock.close()
@unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
def test_reuse_port(self):
socket, port = bind_unused_port(reuse_port=True)
try:
sockets = bind_sockets(port, 'localhost', reuse_port=True)
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
finally:
socket.close()
for sock in sockets:
sock.close()

View file

@ -1,2 +1,5 @@
port=443
port=443
port=443
username='李康'
foo_bar='a'

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, with_statement
import datetime
@ -32,9 +33,11 @@ class OptionsTest(unittest.TestCase):
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.define("username", default='foo')
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEquals(options.port, 443)
self.assertEqual(options.username, "李康")
def test_parse_callbacks(self):
options = OptionParser()
@ -218,3 +221,45 @@ class OptionsTest(unittest.TestCase):
options.define('foo')
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.*already defined')
def test_dash_underscore_cli(self):
# Dashes and underscores should be interchangeable.
for defined_name in ['foo-bar', 'foo_bar']:
for flag in ['--foo-bar=a', '--foo_bar=a']:
options = OptionParser()
options.define(defined_name)
options.parse_command_line(['main.py', flag])
# Attr-style access always uses underscores.
self.assertEqual(options.foo_bar, 'a')
# Dict-style access allows both.
self.assertEqual(options['foo-bar'], 'a')
self.assertEqual(options['foo_bar'], 'a')
def test_dash_underscore_file(self):
# No matter how an option was defined, it can be set with underscores
# in a config file.
for defined_name in ['foo-bar', 'foo_bar']:
options = OptionParser()
options.define(defined_name)
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEqual(options.foo_bar, 'a')
def test_dash_underscore_introspection(self):
# Original names are preserved in introspection APIs.
options = OptionParser()
options.define('with-dash', group='g')
options.define('with_underscore', group='g')
all_options = ['help', 'with-dash', 'with_underscore']
self.assertEqual(sorted(options), all_options)
self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
self.assertEqual(sorted(options.as_dict().keys()), all_options)
self.assertEqual(sorted(options.group_dict('g')),
['with-dash', 'with_underscore'])
# --help shows CLI-style names with dashes.
buf = StringIO()
options.print_help(buf)
self.assertIn('--with-dash', buf.getvalue())
self.assertIn('--with-underscore', buf.getvalue())

View file

@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.instance().start()
IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase):
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
@gen_test
def test_sigchild_future(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'])
ret = yield subproc.wait_for_exit()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase):
ret = self.wait()
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
@gen_test
def test_wait_for_exit_raise(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
with self.assertRaises(subprocess.CalledProcessError) as cm:
yield subproc.wait_for_exit()
self.assertEqual(cm.exception.returncode, 1)
@gen_test
def test_wait_for_exit_raise_disabled(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
ret = yield subproc.wait_for_exit(raise_error=False)
self.assertEqual(ret, 1)

View file

@ -0,0 +1,423 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from datetime import timedelta
from random import random
from tornado import gen, queues
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest, skipBefore35, exec_test
class QueueBasicTest(AsyncTestCase):
def test_repr_and_str(self):
q = queues.Queue(maxsize=1)
self.assertIn(hex(id(q)), repr(q))
self.assertNotIn(hex(id(q)), str(q))
q.get()
for q_str in repr(q), str(q):
self.assertTrue(q_str.startswith('<Queue'))
self.assertIn('maxsize=1', q_str)
self.assertIn('getters[1]', q_str)
self.assertNotIn('putters', q_str)
self.assertNotIn('tasks', q_str)
q.put(None)
q.put(None)
# Now the queue is full, this putter blocks.
q.put(None)
for q_str in repr(q), str(q):
self.assertNotIn('getters', q_str)
self.assertIn('putters[1]', q_str)
self.assertIn('tasks=2', q_str)
def test_order(self):
q = queues.Queue()
for i in [1, 3, 2]:
q.put_nowait(i)
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 3, 2], items)
@gen_test
def test_maxsize(self):
self.assertRaises(TypeError, queues.Queue, maxsize=None)
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
q = queues.Queue(maxsize=2)
self.assertTrue(q.empty())
self.assertFalse(q.full())
self.assertEqual(2, q.maxsize)
self.assertTrue(q.put(0).done())
self.assertTrue(q.put(1).done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
put2 = q.put(2)
self.assertFalse(put2.done())
self.assertEqual(0, (yield q.get())) # Make room.
self.assertTrue(put2.done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
class QueueGetTest(AsyncTestCase):
@gen_test
def test_blocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, (yield q.get()))
def test_nonblocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_get_exception(self):
q = queues.Queue()
self.assertRaises(queues.QueueEmpty, q.get_nowait)
@gen_test
def test_get_with_putters(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1)
self.assertEqual(0, (yield q.get()))
self.assertIsNone((yield put))
@gen_test
def test_blocking_get_wait(self):
q = queues.Queue()
q.put(0)
self.io_loop.call_later(0.01, q.put, 1)
self.io_loop.call_later(0.02, q.put, 2)
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
@gen_test
def test_get_timeout(self):
q = queues.Queue()
get_timeout = q.get(timeout=timedelta(seconds=0.01))
get = q.get()
with self.assertRaises(TimeoutError):
yield get_timeout
q.put_nowait(0)
self.assertEqual(0, (yield get))
@gen_test
def test_get_timeout_preempted(self):
q = queues.Queue()
get = q.get(timeout=timedelta(seconds=0.01))
q.put(0)
yield gen.sleep(0.02)
self.assertEqual(0, (yield get))
@gen_test
def test_get_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11)
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
self.assertEqual(1, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_get_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
self.assertEqual(11, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(11, len(q._getters))
self.assertFalse(get.done()) # Final waiter is still active.
q.get() # get() clears the waiters.
self.assertEqual(2, len(q._getters))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@skipBefore35
@gen_test
def test_async_for(self):
q = queues.Queue()
for i in range(5):
q.put(i)
namespace = exec_test(globals(), locals(), """
async def f():
results = []
async for i in q:
results.append(i)
if i == 4:
return results
""")
results = yield namespace['f']()
self.assertEqual(results, list(range(5)))
class QueuePutTest(AsyncTestCase):
@gen_test
def test_blocking_put(self):
q = queues.Queue()
q.put(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_put_exception(self):
q = queues.Queue(1)
q.put(0)
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
@gen_test
def test_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
yield q.put(0)
self.assertEqual(0, (yield get0))
yield q.put(1)
self.assertEqual(1, (yield get1))
@gen_test
def test_nonblocking_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
q.put_nowait(0)
# put_nowait does *not* immediately unblock getters.
yield gen.moment
self.assertEqual(0, (yield get0))
q.put_nowait(1)
yield gen.moment
self.assertEqual(1, (yield get1))
@gen_test
def test_blocking_put_wait(self):
q = queues.Queue(1)
q.put_nowait(0)
self.io_loop.call_later(0.01, q.get)
self.io_loop.call_later(0.02, q.get)
futures = [q.put(0), q.put(1)]
self.assertFalse(any(f.done() for f in futures))
yield futures
@gen_test
def test_put_timeout(self):
q = queues.Queue(1)
q.put_nowait(0) # Now it's full.
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
put = q.put(2)
with self.assertRaises(TimeoutError):
yield put_timeout
self.assertEqual(0, q.get_nowait())
# 1 was never put in the queue.
self.assertEqual(2, (yield q.get()))
# Final get() unblocked this putter.
yield put
@gen_test
def test_put_timeout_preempted(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1, timeout=timedelta(seconds=0.01))
q.get()
yield gen.sleep(0.02)
yield put # No TimeoutError.
@gen_test
def test_put_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11) # put() clears the waiters.
self.assertEqual(2, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_put_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
q.get()
self.assertEqual(12, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(12, len(q._getters))
self.assertFalse(get.done()) # Final waiters still active.
q.put(0) # put() clears the waiters.
self.assertEqual(1, len(q._getters))
self.assertEqual(0, (yield get))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@gen_test
def test_float_maxsize(self):
# Non-int maxsize must round down: http://bugs.python.org/issue21723
q = queues.Queue(maxsize=1.3)
self.assertTrue(q.empty())
self.assertFalse(q.full())
q.put_nowait(0)
q.put_nowait(1)
self.assertFalse(q.empty())
self.assertTrue(q.full())
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
self.assertEqual(0, q.get_nowait())
self.assertFalse(q.empty())
self.assertFalse(q.full())
yield q.put(2)
put = q.put(3)
self.assertFalse(put.done())
self.assertEqual(1, (yield q.get()))
yield put
self.assertTrue(q.full())
class QueueJoinTest(AsyncTestCase):
queue_class = queues.Queue
def test_task_done_underflow(self):
q = self.queue_class()
self.assertRaises(ValueError, q.task_done)
@gen_test
def test_task_done(self):
q = self.queue_class()
for i in range(100):
q.put_nowait(i)
self.accumulator = 0
@gen.coroutine
def worker():
while True:
item = yield q.get()
self.accumulator += item
q.task_done()
yield gen.sleep(random() * 0.01)
# Two coroutines share work.
worker()
worker()
yield q.join()
self.assertEqual(sum(range(100)), self.accumulator)
@gen_test
def test_task_done_delay(self):
# Verify it is task_done(), not get(), that unblocks join().
q = self.queue_class()
q.put_nowait(0)
join = q.join()
self.assertFalse(join.done())
yield q.get()
self.assertFalse(join.done())
yield gen.moment
self.assertFalse(join.done())
q.task_done()
self.assertTrue(join.done())
@gen_test
def test_join_empty_queue(self):
q = self.queue_class()
yield q.join()
yield q.join()
@gen_test
def test_join_timeout(self):
q = self.queue_class()
q.put(0)
with self.assertRaises(TimeoutError):
yield q.join(timeout=timedelta(seconds=0.01))
class PriorityQueueJoinTest(QueueJoinTest):
queue_class = queues.PriorityQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait((1, 'a'))
q.put_nowait((0, 'b'))
self.assertTrue(q.full())
q.put((3, 'c'))
q.put((2, 'd'))
self.assertEqual((0, 'b'), q.get_nowait())
self.assertEqual((1, 'a'), (yield q.get()))
self.assertEqual((2, 'd'), q.get_nowait())
self.assertEqual((3, 'c'), (yield q.get()))
self.assertTrue(q.empty())
class LifoQueueJoinTest(QueueJoinTest):
queue_class = queues.LifoQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait(1)
q.put_nowait(0)
self.assertTrue(q.full())
q.put(3)
q.put(2)
self.assertEqual(3, q.get_nowait())
self.assertEqual(2, (yield q.get()))
self.assertEqual(0, q.get_nowait())
self.assertEqual(1, (yield q.get()))
self.assertTrue(q.empty())
class ProducerConsumerTest(AsyncTestCase):
@gen_test
def test_producer_consumer(self):
q = queues.Queue(maxsize=3)
history = []
# We don't yield between get() and task_done(), so get() must wait for
# the next tick. Otherwise we'd immediately call task_done and unblock
# join() before q.put() resumes, and we'd only process the first four
# items.
@gen.coroutine
def consumer():
while True:
history.append((yield q.get()))
q.task_done()
@gen.coroutine
def producer():
for item in range(10):
yield q.put(item)
consumer()
yield producer()
yield q.join()
self.assertEqual(list(range(10)), history)
if __name__ == '__main__':
unittest.main()

View file

@ -8,6 +8,7 @@ import operator
import textwrap
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
@ -22,6 +23,7 @@ TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.asyncio_test',
'tornado.test.auth_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
@ -34,13 +36,16 @@ TEST_MODULES = [
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
'tornado.test.locale_test',
'tornado.test.locks_test',
'tornado.test.netutil_test',
'tornado.test.log_test',
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.queues_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.tcpserver_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
@ -67,6 +72,21 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
return result
class LogCounter(logging.Filter):
"""Counts the number of WARNING or higher log records."""
def __init__(self, *args, **kwargs):
# Can't use super() because logging.Filter is an old-style class in py26
logging.Filter.__init__(self, *args, **kwargs)
self.warning_count = self.error_count = 0
def filter(self, record):
if record.levelno >= logging.ERROR:
self.error_count += 1
elif record.levelno >= logging.WARNING:
self.warning_count += 1
return True
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
@ -92,12 +112,21 @@ def main():
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
# unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
# instead of DeprecationWarnings.
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
message="Please use assert.* instead")
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
warnings.filterwarnings("ignore", category=BytesWarning,
module=r"twisted\..*")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
define('httpserver', type=str, default=None,
callback=HTTPServer.configure)
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
@ -121,6 +150,10 @@ def main():
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
log_counter = LogCounter()
add_parse_callback(
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
@ -131,7 +164,16 @@ def main():
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
tornado.testing.main(**kwargs)
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they logged
# any warnings or errors. We'd like to ban info logs too, but
# we can't count them cleanly due to interactions with LogTrapTestCase.
if log_counter.warning_count > 0 or log_counter.error_count > 0:
logging.error("logged %d warnings and %d errors",
log_counter.warning_count, log_counter.error_count)
sys.exit(1)
if __name__ == '__main__':
main()

View file

@ -8,19 +8,21 @@ import logging
import os
import re
import socket
import ssl
import sys
from tornado.escape import to_unicode
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.log import gen_log, app_log
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest, skipBefore35, exec_test
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
@ -97,15 +99,18 @@ class HostEchoHandler(RequestHandler):
class NoContentLengthHandler(RequestHandler):
@gen.coroutine
@asynchronous
def get(self):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.stream
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
if self.request.version.startswith('HTTP/1'):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.detach()
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
else:
self.finish('HTTP/1 required')
class EchoPostHandler(RequestHandler):
@ -141,6 +146,7 @@ class SimpleHTTPClientTestMixin(object):
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
url("/redirect", RedirectHandler),
], gzip=True)
def test_singleton(self):
@ -191,9 +197,6 @@ class SimpleHTTPClientTestMixin(object):
response = self.wait()
response.rethrow()
def test_default_certificates_exist(self):
open(_default_ca_certs()).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
@ -204,6 +207,7 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
# Chunked encoding bypasses the MIN_LENGTH check.
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
@ -235,9 +239,16 @@ class SimpleHTTPClientTestMixin(object):
@skipOnTravis
def test_request_timeout(self):
response = self.fetch('/trigger?wake=false', request_timeout=0.1)
timeout = 0.1
timeout_min, timeout_max = 0.099, 0.15
if os.name == 'nt':
timeout = 0.5
timeout_min, timeout_max = 0.4, 0.6
response = self.fetch('/trigger?wake=false', request_timeout=timeout)
self.assertEqual(response.code, 599)
self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
self.assertTrue(timeout_min < response.request_time < timeout_max,
response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@ -294,10 +305,13 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway.
#
# A test without a content-length header is included below
# in HTTP204NoContentTestCase.
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(gen_log, "Malformed HTTP message"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
@ -312,10 +326,10 @@ class SimpleHTTPClientTestMixin(object):
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
server_socket, port = bind_unused_port()
server_socket.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
self.http_client.fetch("http://localhost:%d/" % port, self.stop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
@ -349,7 +363,10 @@ class SimpleHTTPClientTestMixin(object):
def test_no_content_length(self):
response = self.fetch("/no_content_length")
self.assertEquals(b"hello", response.body)
if response.body == b"HTTP/1 required":
self.skipTest("requires HTTP/1.x")
else:
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
@ -387,6 +404,33 @@ class SimpleHTTPClientTestMixin(object):
response.rethrow()
self.assertEqual(response.body, b"12345678")
@skipBefore35
def test_native_body_producer_chunked(self):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
await gen.Task(IOLoop.current().add_callback)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
body_producer=namespace["body_producer"])
response.rethrow()
self.assertEqual(response.body, b"12345678")
@skipBefore35
def test_native_body_producer_content_length(self):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
await gen.Task(IOLoop.current().add_callback)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
body_producer=namespace["body_producer"],
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch("/echo_post", method="POST",
body=b"1234",
@ -401,6 +445,24 @@ class SimpleHTTPClientTestMixin(object):
expect_100_continue=True)
self.assertEqual(response.code, 403)
def test_streaming_follow_redirects(self):
# When following redirects, header and streaming callbacks
# should only be called for the final result.
# TODO(bdarnell): this test belongs in httpclient_test instead of
# simple_httpclient_test, but it fails with the version of libcurl
# available on travis-ci. Move it when that has been upgraded
# or we have a better framework to skip tests based on curl version.
headers = []
chunks = []
self.fetch("/redirect?url=/hello",
header_callback=headers.append,
streaming_callback=chunks.append)
chunks = list(map(to_unicode, chunks))
self.assertEqual(chunks, ['Hello world!'])
# Make sure we only got one set of headers.
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
self.assertEqual(num_start_lines, 1)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
@ -422,6 +484,43 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
defaults=dict(validate_cert=False),
**kwargs)
def test_ssl_options(self):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context(self):
resp = self.fetch("/hello",
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_options_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
required=False):
resp = self.fetch(
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
self.assertRaises(ssl.SSLError, resp.rethrow)
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
resp = self.fetch("/hello", ssl_options=ctx)
self.assertRaises(ssl.SSLError, resp.rethrow)
def test_error_logging(self):
# No stack traces are logged for SSL errors (in this case,
# failure to validate the testing self-signed cert).
# The SSLError is exposed through ssl.SSLError.
with ExpectLog(gen_log, '.*') as expect_log:
response = self.fetch("/", validate_cert=True)
self.assertEqual(response.code, 599)
self.assertIsInstance(response.error, ssl.SSLError)
self.assertFalse(expect_log.logged_stack)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
@ -457,6 +556,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
@ -473,9 +578,43 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def test_100_continue(self):
res = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
# Close the request cleanly in HTTP/2; it will be skipped anyway.
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
# no content length and does not close the connection.
#
# Tests of a 204 response with a Content-Length header are included
# in SimpleHTTPClientTestMixin.
stream = request.connection.detach()
stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
stream.close()
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
@ -550,3 +689,71 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
class MaxBodySizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 64)
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([('/small', SmallBody),
('/large', LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 64)
def test_small_body(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'a' * 1024 * 64)
def test_large_body(self):
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_app(self):
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([('/large', LargeBody)])
def get_http_client(self):
# 100KB body with 64KB buffer
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024 * 100, max_buffer_size=1024 * 64)
def test_large_body(self):
response = self.fetch('/large')
response.rethrow()
self.assertEqual(response.body, b'a' * 1024 * 100)
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
def get_app(self):
class ChunkedWithContentLength(RequestHandler):
def get(self):
# Add an invalid Transfer-Encoding to the response
self.set_header('Transfer-Encoding', 'chunked')
self.write("Hello world")
return Application([('/chunkwithcl', ChunkedWithContentLength)])
def get_http_client(self):
return SimpleAsyncHTTPClient()
def test_chunked_with_content_length(self):
# Make sure the invalid headers are detected
with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
"with both Transfer-Encoding and Content-Length")):
response = self.fetch('/chunkwithcl')
self.assertEqual(response.code, 599)

View file

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<data>
<country name="Liechtenstein">
<rank>1</rank>
<year>2008</year>
<gdppc>141100</gdppc>
<neighbor name="Austria" direction="E"/>
<neighbor name="Switzerland" direction="W"/>
</country>
<country name="Singapore">
<rank>4</rank>
<year>2011</year>
<gdppc>59900</gdppc>
<neighbor name="Malaysia" direction="N"/>
</country>
<country name="Panama">
<rank>68</rank>
<year>2011</year>
<gdppc>13600</gdppc>
<neighbor name="Costa Rica" direction="W"/>
<neighbor name="Colombia" direction="E"/>
</country>
</data>

View file

@ -0,0 +1,2 @@
This file should not be served by StaticFileHandler even though
its name starts with "static".

View file

@ -24,8 +24,8 @@ from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
from tornado.test.util import skipIfNoIPv6, unittest
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
@ -72,7 +72,9 @@ class TCPClientTest(AsyncTestCase):
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop)
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
Resolver().resolve('localhost', 80, callback=self.stop)
addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
@ -118,8 +120,8 @@ class TCPClientTest(AsyncTestCase):
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
sock.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)

View file

@ -0,0 +1,39 @@
from __future__ import absolute_import, division, print_function, with_statement
import socket
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
class TCPServerTest(AsyncTestCase):
@gen_test
def test_handle_stream_coroutine_logging(self):
# handle_stream may be a coroutine and any exception in its
# Future will be logged.
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield gen.moment
stream.close()
1 / 0
server = client = None
try:
sock, port = bind_unused_port()
with NullContext():
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
yield client.connect(('localhost', port))
yield client.read_until_close()
yield gen.moment
finally:
if server is not None:
server.stop()
if client is not None:
client.close()

View file

@ -7,7 +7,7 @@ import traceback
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.util import u, ObjectDict, unicode_type
class TemplateTest(unittest.TestCase):
@ -173,6 +173,10 @@ try{% set y = 1/x %}
template = Template('{{ 1 / 2 }}')
self.assertEqual(template.generate(), '0')
def test_non_ascii_name(self):
loader = DictLoader({u("t\u00e9st.html"): "hello"})
self.assertEqual(loader.load(u("t\u00e9st.html")).generate(), b"hello")
class StackTraceTest(unittest.TestCase):
def test_error_line_number_expression(self):
@ -264,6 +268,19 @@ three{%end%}
traceback.format_exc())
class ParseErrorDetailTest(unittest.TestCase):
def test_details(self):
loader = DictLoader({
"foo.html": "\n\n{{",
})
with self.assertRaises(ParseError) as cm:
loader.load("foo.html")
self.assertEqual("Missing end expression }} at foo.html:3",
str(cm.exception))
self.assertEqual("foo.html", cm.exception.filename)
self.assertEqual(3, cm.exception.lineno)
class AutoEscapeTest(unittest.TestCase):
def setUp(self):
self.templates = {
@ -374,7 +391,7 @@ raw: {% raw name %}""",
"{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes_type)
self.assertEqual(type(s), bytes)
return repr(native_str(s))
def render(template, name):
@ -387,7 +404,7 @@ raw: {% raw name %}""",
self.assertEqual(render("foo.py", ["not a string"]),
b"""s = "['not a string']"\n""")
def test_minimize_whitespace(self):
def test_manual_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
@ -401,6 +418,62 @@ raw: {% raw name %}""",
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
b"0, 1, 2, 3, 4")
def test_whitespace_by_filename(self):
# Default whitespace handling depends on the template filename.
loader = DictLoader({
"foo.html": " \n\t\n asdf\t ",
"bar.js": " \n\n\n\t qwer ",
"baz.txt": "\t zxcv\n\n",
"include.html": " {% include baz.txt %} \n ",
"include.txt": "\t\t{% include foo.html %} ",
})
# HTML and JS files have whitespace compressed by default.
self.assertEqual(loader.load("foo.html").generate(),
b"\nasdf ")
self.assertEqual(loader.load("bar.js").generate(),
b"\nqwer ")
# TXT files do not.
self.assertEqual(loader.load("baz.txt").generate(),
b"\t zxcv\n\n")
# Each file maintains its own status even when included in
# a file of the other type.
self.assertEqual(loader.load("include.html").generate(),
b" \t zxcv\n\n\n")
self.assertEqual(loader.load("include.txt").generate(),
b"\t\t\nasdf ")
def test_whitespace_by_loader(self):
templates = {
"foo.html": "\t\tfoo\n\n",
"bar.txt": "\t\tbar\n\n",
}
loader = DictLoader(templates, whitespace='all')
self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n")
self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n")
loader = DictLoader(templates, whitespace='single')
self.assertEqual(loader.load("foo.html").generate(), b" foo\n")
self.assertEqual(loader.load("bar.txt").generate(), b" bar\n")
loader = DictLoader(templates, whitespace='oneline')
self.assertEqual(loader.load("foo.html").generate(), b" foo ")
self.assertEqual(loader.load("bar.txt").generate(), b" bar ")
def test_whitespace_directive(self):
loader = DictLoader({
"foo.html": """\
{% whitespace oneline %}
{% for i in range(3) %}
{{ i }}
{% end %}
{% whitespace all %}
pre\tformatted
"""})
self.assertEqual(loader.load("foo.html").generate(),
b" 0 1 2 \n pre\tformatted\n")
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):

View file

@ -3,12 +3,13 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen, ioloop
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
from tornado.log import app_log
from tornado.testing import AsyncTestCase, gen_test, ExpectLog
from tornado.test.util import unittest, skipBefore35, exec_test
import contextlib
import os
import traceback
import warnings
@contextlib.contextmanager
@ -57,11 +58,22 @@ class AsyncTestCaseTest(AsyncTestCase):
This test makes sure that a second call to wait()
clears the first timeout.
"""
self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
self.io_loop.add_timeout(self.io_loop.time() + 0.00, self.stop)
self.wait(timeout=0.02)
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15)
def test_multiple_errors(self):
def fail(message):
raise Exception(message)
self.io_loop.add_callback(lambda: fail("error one"))
self.io_loop.add_callback(lambda: fail("error two"))
# The first error gets raised; the second gets logged.
with ExpectLog(app_log, "multiple unhandled exceptions"):
with self.assertRaises(Exception) as cm:
self.wait()
self.assertEqual(str(cm.exception), "error one")
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):
@ -74,6 +86,26 @@ class AsyncTestCaseWrapperTest(unittest.TestCase):
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
@skipBefore35
def test_undecorated_coroutine(self):
namespace = exec_test(globals(), locals(), """
class Test(AsyncTestCase):
async def test_coro(self):
pass
""")
test_class = namespace['Test']
test = test_class('test_coro')
result = unittest.TestResult()
# Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
with warnings.catch_warnings():
warnings.simplefilter('ignore')
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
def test_undecorated_generator_with_skip(self):
class Test(AsyncTestCase):
@unittest.skip("don't run this")
@ -216,5 +248,31 @@ class GenTest(AsyncTestCase):
test_with_kwargs(self, test='test')
self.finished = True
@skipBefore35
def test_native_coroutine(self):
namespace = exec_test(globals(), locals(), """
@gen_test
async def test(self):
self.finished = True
""")
namespace['test'](self)
@skipBefore35
def test_native_coroutine_timeout(self):
# Set a short timeout and exceed it.
namespace = exec_test(globals(), locals(), """
@gen_test(timeout=0.1)
async def test(self):
await gen.sleep(1)
""")
try:
namespace['test'](self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
self.finished = True
if __name__ == '__main__':
unittest.main()

View file

@ -19,15 +19,18 @@ Unittest for the twisted-style reactor.
from __future__ import absolute_import, division, print_function, with_statement
import logging
import os
import shutil
import signal
import sys
import tempfile
import threading
import warnings
try:
import fcntl
from twisted.internet.defer import Deferred
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
@ -40,10 +43,12 @@ except ImportError:
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
# so test for it separately.
try:
from twisted.web.client import Agent
from twisted.web.client import Agent, readBody
from twisted.web.resource import Resource
from twisted.web.server import Site
have_twisted_web = True
# As of Twisted 15.0.0, twisted.web is present but fails our
# tests due to internal str/bytes errors.
have_twisted_web = sys.version_info < (3,)
except ImportError:
have_twisted_web = False
@ -52,6 +57,8 @@ try:
except ImportError:
import _thread as thread # py3
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -65,6 +72,9 @@ from tornado.web import RequestHandler, Application
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
skipIfPy26 = unittest.skipIf(sys.version_info < (2, 7),
"twisted incompatible with singledispatch in py26")
def save_signal_handlers():
saved = {}
@ -407,7 +417,7 @@ class CompatibilityTests(unittest.TestCase):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
d = client.request('GET', url)
d = client.request(b'GET', utf8(url))
class Accumulator(Protocol):
def __init__(self, finished):
@ -425,37 +435,98 @@ class CompatibilityTests(unittest.TestCase):
return finished
d.addCallback(callback)
def shutdown(ignored):
self.stop_loop()
def shutdown(failure):
if hasattr(self, 'stop_loop'):
self.stop_loop()
elif failure is not None:
# loop hasn't been initialized yet; try our best to
# get an error message out. (the runner() interaction
# should probably be refactored).
try:
failure.raiseException()
except:
logging.error('exception before starting loop', exc_info=True)
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return ''.join(chunks)
def twisted_coroutine_fetch(self, url, runner):
body = [None]
@gen.coroutine
def f():
# This is simpler than the non-coroutine version, but it cheats
# by reading the body in one blob instead of streaming it with
# a Protocol.
client = Agent(self.reactor)
response = yield client.request(b'GET', utf8(url))
with warnings.catch_warnings():
# readBody has a buggy DeprecationWarning in Twisted 15.0:
# https://twistedmatrix.com/trac/changeset/43379
warnings.simplefilter('ignore', category=DeprecationWarning)
body[0] = yield readBody(response)
self.stop_loop()
self.io_loop.add_callback(f)
runner()
return body[0]
def testTwistedServerTornadoClientIOLoop(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_ioloop)
'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop)
self.assertEqual(response.body, 'Hello from twisted!')
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_reactor)
'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
self.assertEqual(response.body, 'Hello from twisted!')
def testTornadoServerTwistedClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_ioloop)
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_reactor)
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
@skipIfPy26
def testTornadoServerTwistedCoroutineClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_coroutine_fetch(
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
@skipIfNoTwisted
@skipIfPy26
class ConvertDeferredTest(unittest.TestCase):
def test_success(self):
@inlineCallbacks
def fn():
if False:
# inlineCallbacks doesn't work with regular functions;
# must have a yield even if it's unreachable.
yield
returnValue(42)
f = gen.convert_yielded(fn())
self.assertEqual(f.result(), 42)
def test_failure(self):
@inlineCallbacks
def fn():
if False:
yield
1 / 0
f = gen.convert_yielded(fn())
with self.assertRaises(ZeroDivisionError):
f.result()
if have_twisted:
# Import and run as much of twisted's test suite as possible.
@ -481,9 +552,13 @@ if have_twisted:
# with py27+, but not unittest2 on py26.
'test_changeGID',
'test_changeUID',
# This test sometimes fails with EPIPE on a call to
# kqueue.control. Happens consistently for me with
# trollius but not asyncio or other IOLoops.
'test_childConnectionLost',
],
# Process tests appear to work on OSX 10.7, but not 10.6
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'test_systemCallUninterruptedByChildExit',
# ],
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
@ -502,7 +577,7 @@ if have_twisted:
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
'twisted.internet.test.test_time.TimeTestsBuilder': [],
# Extra third-party dependencies (pyOpenSSL)
#'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
# 'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
# Platform-specific. These tests would be skipped automatically
@ -522,6 +597,14 @@ if have_twisted:
],
'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [],
}
if sys.version_info >= (3,):
# In Twisted 15.2.0 on Python 3.4, the process tests will try to run
# but fail, due in part to interactions between Tornado's strict
# warnings-as-errors policy and Twisted's own warning handling
# (it was not obvious how to configure the warnings module to
# reconcile the two), and partly due to what looks like a packaging
# error (process_cli.py missing). For now, just skip it.
del twisted_tests['twisted.internet.test.test_process.ProcessTestsBuilder']
for test_name, blacklist in twisted_tests.items():
try:
test_class = import_object(test_name)
@ -551,6 +634,24 @@ if have_twisted:
os.chdir(self.__curdir)
shutil.rmtree(self.__tempdir)
def flushWarnings(self, *args, **kwargs):
# This is a hack because Twisted and Tornado have
# differing approaches to warnings in tests.
# Tornado sets up a global set of warnings filters
# in runtests.py, while Twisted patches the filter
# list in each test. The net effect is that
# Twisted's tests run with Tornado's increased
# strictness (BytesWarning and ResourceWarning are
# enabled) but without our filter rules to ignore those
# warnings from Twisted code.
filtered = []
for w in super(TornadoTest, self).flushWarnings(
*args, **kwargs):
if w['category'] in (BytesWarning, ResourceWarning):
continue
filtered.append(w)
return filtered
def buildReactor(self):
self.__saved_signals = save_signal_handlers()
return test_class.buildReactor(self)
@ -579,6 +680,14 @@ if have_twisted:
# log.startLoggingWithObserver(log.PythonLoggingObserver().emit, setStdout=0)
# import logging; logging.getLogger('twisted').setLevel(logging.WARNING)
# Twisted recently introduced a new logger; disable that one too.
try:
from twisted.logger import globalLogBeginner
except ImportError:
pass
else:
globalLogBeginner.beginLoggingTo([])
if have_twisted:
class LayeredTwistedIOLoop(TwistedIOLoop):
"""Layers a TwistedIOLoop on top of a TornadoReactor on a SelectIOLoop.
@ -588,13 +697,13 @@ if have_twisted:
correctly. In some tests another TornadoReactor is layered on top
of the whole stack.
"""
def initialize(self):
def initialize(self, **kwargs):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop()
self.real_io_loop = SelectIOLoop(make_current=False)
reactor = TornadoReactor(io_loop=self.real_io_loop)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
self.add_callback(self.make_current)
def close(self, all_fds=False):
@ -612,7 +721,12 @@ if have_twisted:
# tornado-on-twisted-on-tornado. I'm clearly missing something
# about the startup/crash semantics, but since stop and crash
# are really only used in tests it doesn't really matter.
self.reactor.callWhenRunning(self.reactor.crash)
def f():
self.reactor.crash()
# Become current again on restart. This is needed to
# override real_io_loop's claim to being the current loop.
self.add_callback(self.make_current)
self.reactor.callWhenRunning(f)
if __name__ == "__main__":
unittest.main()

View file

@ -1,8 +1,12 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import platform
import socket
import sys
import textwrap
from tornado.testing import bind_unused_port
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
@ -22,9 +26,53 @@ skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
'timing tests unreliable on travis')
skipOnAppEngine = unittest.skipIf('APPENGINE_RUNTIME' in os.environ,
'not available on Google App Engine')
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 (yield from) not available')
skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
'Not CPython implementation')
def refusing_port():
"""Returns a local port number that will refuse all connections.
Return value is (cleanup_func, port); the cleanup function
must be called to free the port to be reused.
"""
# On travis-ci, port numbers are reassigned frequently. To avoid
# collisions with other tests, we use an open client-side socket's
# ephemeral port number to ensure that nothing can listen on that
# port.
server_socket, port = bind_unused_port()
server_socket.setblocking(1)
client_socket = socket.socket()
client_socket.connect(("127.0.0.1", port))
conn, client_addr = server_socket.accept()
conn.close()
server_socket.close()
return (client_socket.close, client_addr[1])
def exec_test(caller_globals, caller_locals, s):
"""Execute ``s`` in a given context and return the result namespace.
Used to define functions for tests in particular python
versions that would be syntax errors in older versions.
"""
# Flatten the real global and local namespace into our fake
# globals: it's all global from the perspective of code defined
# in s.
global_namespace = dict(caller_globals, **caller_locals)
local_namespace = {}
exec(textwrap.dedent(s), global_namespace, local_namespace)
return local_namespace

View file

@ -1,9 +1,11 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement
import sys
import datetime
import tornado.escape
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
from tornado.test.util import unittest
try:
@ -44,13 +46,15 @@ class TestConfigurable(Configurable):
class TestConfig1(TestConfigurable):
def initialize(self, a=None):
def initialize(self, pos_arg=None, a=None):
self.a = a
self.pos_arg = pos_arg
class TestConfig2(TestConfigurable):
def initialize(self, b=None):
def initialize(self, pos_arg=None, b=None):
self.b = b
self.pos_arg = pos_arg
class ConfigurableTest(unittest.TestCase):
@ -100,9 +104,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
obj = TestConfigurable(a=4)
obj = TestConfigurable(42, a=4)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@ -115,9 +120,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
obj = TestConfigurable(b=6)
obj = TestConfigurable(42, b=6)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@ -170,3 +176,26 @@ class ArgReplacerTest(unittest.TestCase):
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))
class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
class ImportObjectTest(unittest.TestCase):
def test_import_member(self):
self.assertIs(import_object('tornado.escape.utf8'), utf8)
def test_import_member_unicode(self):
self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
def test_import_module(self):
self.assertIs(import_object('tornado.escape'), tornado.escape)
def test_import_module_unicode(self):
# The internal implementation of __import__ differs depending on
# whether the thing being imported is a module or not.
# This variant requires a byte string in python 2.
self.assertIs(import_object(u('tornado.escape')), tornado.escape)

View file

@ -3,25 +3,29 @@ from tornado.concurrent import Future
from tornado import gen
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
from tornado.httputil import format_timestamp
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import locale
from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipBefore35, exec_test
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version, GZipContentEncoding
import binascii
import contextlib
import copy
import datetime
import email.utils
import gzip
from io import BytesIO
import itertools
import logging
import os
import re
import socket
import sys
try:
import urllib.parse as urllib_parse # py3
@ -71,10 +75,14 @@ class HelloHandler(RequestHandler):
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
def __init__(self):
def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
if key_version is None:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
else:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
key_version=key_version))
def get_cookie(self, name):
return self._cookies.get(name)
@ -128,6 +136,51 @@ class SecureCookieV1Test(unittest.TestCase):
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
# See SignedValueTest below for more.
class SecureCookieV2Test(unittest.TestCase):
KEY_VERSIONS = {
0: 'ajklasdf0ojaisdf',
1: 'aslkjasaolwkjsdf'
}
def test_round_trip(self):
handler = CookieTestRequestHandler()
handler.set_secure_cookie('foo', b'bar', version=2)
self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
def test_key_version_roundtrip(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
def test_key_version_roundtrip_differing_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
handler.set_secure_cookie('foo', b'bar')
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
def test_key_version_increment_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
def test_key_version_invalidate_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
new_key_versions = self.KEY_VERSIONS.copy()
new_key_versions.pop(0)
new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), None)
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
@ -163,11 +216,29 @@ class CookieTest(WebTestCase):
# Attributes from the first call are not carried over.
self.set_cookie("a", "e")
class SetCookieMaxAgeHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", max_age=10)
class SetCookieExpiresDaysHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", expires_days=10)
class SetCookieFalsyFlags(RequestHandler):
def get(self):
self.set_cookie("a", "1", secure=True)
self.set_cookie("b", "1", secure=False)
self.set_cookie("c", "1", httponly=True)
self.set_cookie("d", "1", httponly=False)
return [("/set", SetCookieHandler),
("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler),
("/special_char", SetCookieSpecialCharHandler),
("/set_overwrite", SetCookieOverwriteHandler),
("/set_max_age", SetCookieMaxAgeHandler),
("/set_expires_days", SetCookieExpiresDaysHandler),
("/set_falsy_flags", SetCookieFalsyFlags)
]
def test_set_cookie(self):
@ -222,6 +293,33 @@ class CookieTest(WebTestCase):
self.assertEqual(sorted(headers),
["a=e; Path=/", "c=d; Domain=example.com; Path=/"])
def test_set_cookie_max_age(self):
response = self.fetch("/set_max_age")
headers = response.headers.get_list("Set-Cookie")
self.assertEqual(sorted(headers),
["foo=bar; Max-Age=10; Path=/"])
def test_set_cookie_expires_days(self):
response = self.fetch("/set_expires_days")
header = response.headers.get("Set-Cookie")
match = re.match("foo=bar; expires=(?P<expires>.+); Path=/", header)
self.assertIsNotNone(match)
expires = datetime.datetime.utcnow() + datetime.timedelta(days=10)
header_expires = datetime.datetime(
*email.utils.parsedate(match.groupdict()["expires"])[:6])
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
def test_set_cookie_false_flags(self):
response = self.fetch("/set_falsy_flags")
headers = sorted(response.headers.get_list("Set-Cookie"))
# The secure and httponly headers are capitalized in py35 and
# lowercase in older versions.
self.assertEqual(headers[0].lower(), 'a=1; path=/; secure')
self.assertEqual(headers[1].lower(), 'b=1; path=/')
self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/')
self.assertEqual(headers[3].lower(), 'd=1; path=/')
class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url):
@ -278,7 +376,7 @@ class ConnectionCloseTest(WebTestCase):
def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
s.connect(("127.0.0.1", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
self.wait()
@ -302,7 +400,7 @@ class EchoHandler(RequestHandler):
if type(key) != str:
raise Exception("incorrect type for key: %r" % type(key))
for value in self.request.arguments[key]:
if type(value) != bytes_type:
if type(value) != bytes:
raise Exception("incorrect type for value: %r" %
type(value))
for value in self.get_arguments(key):
@ -352,6 +450,12 @@ class RequestEncodingTest(WebTestCase):
path_args=["a/b", "c/d"],
args={}))
def test_error(self):
# Percent signs (encoded as %25) should not mess up printf-style
# messages in logs
with ExpectLog(gen_log, ".*Invalid unicode"):
self.fetch("/group/?arg=%25%e9")
class TypeCheckHandler(RequestHandler):
def prepare(self):
@ -370,10 +474,10 @@ class TypeCheckHandler(RequestHandler):
if list(self.cookies.keys()) != ['asdf']:
raise Exception("unexpected values for cookie keys: %r" %
self.cookies.keys())
self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes_type)
self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes)
self.check_type('get_cookie', self.get_cookie('asdf'), str)
self.check_type('xsrf_token', self.xsrf_token, bytes_type)
self.check_type('xsrf_token', self.xsrf_token, bytes)
self.check_type('xsrf_form_html', self.xsrf_form_html(), str)
self.check_type('reverse_url', self.reverse_url('typecheck', 'foo'), str)
@ -399,7 +503,7 @@ class TypeCheckHandler(RequestHandler):
class DecodeArgHandler(RequestHandler):
def decode_argument(self, value, name=None):
if type(value) != bytes_type:
if type(value) != bytes:
raise Exception("unexpected type for value: %r" % type(value))
# use self.request.arguments directly to avoid recursion
if 'encoding' in self.request.arguments:
@ -409,7 +513,7 @@ class DecodeArgHandler(RequestHandler):
def get(self, arg):
def describe(s):
if type(s) == bytes_type:
if type(s) == bytes:
return ["bytes", native_str(binascii.b2a_hex(s))]
elif type(s) == unicode_type:
return ["unicode", s]
@ -470,8 +574,8 @@ class RedirectHandler(RequestHandler):
class EmptyFlushCallbackHandler(RequestHandler):
@gen.engine
@asynchronous
@gen.engine
def get(self):
# Ensure that the flush callback is run whether or not there
# was any output. The gen.Task and direct yield forms are
@ -550,6 +654,9 @@ class WSGISafeWebTest(WebTestCase):
url("/optional_path/(.+)?", OptionalPathHandler),
url("/multi_header", MultiHeaderHandler),
url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
@ -675,6 +782,19 @@ js_embed()
response = self.fetch("/redirect?status=307", follow_redirects=False)
self.assertEqual(response.code, 307)
def test_web_redirect(self):
response = self.fetch("/web_redirect_permanent", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
response = self.fetch("/web_redirect", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_web_redirect_double_slash(self):
response = self.fetch("//web_redirect_double_slash", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_header_injection(self):
response = self.fetch("/header_injection")
self.assertEqual(response.body, b"ok")
@ -851,7 +971,8 @@ class StaticFileTest(WebTestCase):
return [('/static_url/(.*)', StaticUrlHandler),
('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
('/override_static_url/(.*)', OverrideStaticUrlHandler)]
('/override_static_url/(.*)', OverrideStaticUrlHandler),
('/root_static/(.*)', StaticFileHandler, dict(path='/'))]
def get_app_kwargs(self):
return dict(static_path=relpath('static'))
@ -862,6 +983,19 @@ class StaticFileTest(WebTestCase):
response = self.fetch('/static/robots.txt')
self.assertTrue(b"Disallow: /" in response.body)
self.assertEqual(response.headers.get("Content-Type"), "text/plain")
def test_static_compressed_files(self):
response = self.fetch("/static/sample.xml.gz")
self.assertEqual(response.headers.get("Content-Type"),
"application/gzip")
response = self.fetch("/static/sample.xml.bz2")
self.assertEqual(response.headers.get("Content-Type"),
"application/octet-stream")
# make sure the uncompressed file still has the correct type
response = self.fetch("/static/sample.xml")
self.assertTrue(response.headers.get("Content-Type")
in set(("text/xml", "application/xml")))
def test_static_url(self):
response = self.fetch("/static_url/robots.txt")
@ -1065,6 +1199,30 @@ class StaticFileTest(WebTestCase):
response = self.get_and_head('/static/blarg')
self.assertEqual(response.code, 404)
def test_path_traversal_protection(self):
# curl_httpclient processes ".." on the client side, so we
# must test this with simple_httpclient.
self.http_client.close()
self.http_client = SimpleAsyncHTTPClient()
with ExpectLog(gen_log, ".*not in root static directory"):
response = self.get_and_head('/static/../static_foo.txt')
# Attempted path traversal should result in 403, not 200
# (which means the check failed and the file was served)
# or 404 (which means that the file didn't exist and
# is probably a packaging error).
self.assertEqual(response.code, 403)
@unittest.skipIf(os.name != 'posix', 'non-posix OS')
def test_root_static_path(self):
# Sometimes people set the StaticFileHandler's path to '/'
# to disable Tornado's path validation (in conjunction with
# their own validation in get_absolute_path). Make sure
# that the stricter validation in 4.2.1 doesn't break them.
path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'static/robots.txt')
response = self.get_and_head('/root_static' + urllib_parse.quote(path))
self.assertEqual(response.code, 200)
@wsgi_safe
class StaticDefaultFilenameTest(WebTestCase):
@ -1345,10 +1503,13 @@ class GzipTestCase(SimpleHandlerTestCase):
def get(self):
if self.get_argument('vary', None):
self.set_header('Vary', self.get_argument('vary'))
self.write('hello world')
# Must write at least MIN_LENGTH bytes to activate compression.
self.write('hello world' + ('!' * GZipContentEncoding.MIN_LENGTH))
def get_app_kwargs(self):
return dict(gzip=True)
return dict(
gzip=True,
static_path=os.path.join(os.path.dirname(__file__), 'static'))
def test_gzip(self):
response = self.fetch('/')
@ -1361,6 +1522,17 @@ class GzipTestCase(SimpleHandlerTestCase):
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_static(self):
# The streaming responses in StaticFileHandler have subtle
# interactions with the gzip output so test this case separately.
response = self.fetch('/robots.txt')
self.assertEqual(
response.headers.get(
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_not_requested(self):
response = self.fetch('/', use_gzip=False)
self.assertNotIn('Content-Encoding', response.headers)
@ -1409,8 +1581,11 @@ class ClearAllCookiesTest(SimpleHandlerTestCase):
def test_clear_all_cookies(self):
response = self.fetch('/', headers={'Cookie': 'foo=bar; baz=xyzzy'})
set_cookies = sorted(response.headers.get_list('Set-Cookie'))
self.assertTrue(set_cookies[0].startswith('baz=;'))
self.assertTrue(set_cookies[1].startswith('foo=;'))
# Python 3.5 sends 'baz="";'; older versions use 'baz=;'
self.assertTrue(set_cookies[0].startswith('baz=;') or
set_cookies[0].startswith('baz="";'))
self.assertTrue(set_cookies[1].startswith('foo=;') or
set_cookies[1].startswith('foo="";'))
class PermissionError(Exception):
@ -1467,6 +1642,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase):
self.assertEqual(response.code, 403)
@wsgi_safe
class BuggyLoggingTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
1 / 0
def log_exception(self, typ, value, tb):
1 / 0
def test_buggy_log_exception(self):
# Something gets logged even though the application's
# logger is broken.
with ExpectLog(app_log, '.*'):
self.fetch('/')
@wsgi_safe
class UIMethodUIModuleTest(SimpleHandlerTestCase):
"""Test that UI methods and modules are created correctly and
@ -1483,6 +1674,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase):
def my_ui_method(handler, x):
return "In my_ui_method(%s) with handler value %s." % (
x, handler.value())
class MyModule(UIModule):
def render(self, x):
return "In MyModule(%s) with handler value %s." % (
@ -1554,19 +1746,26 @@ class MultipleExceptionTest(SimpleHandlerTestCase):
@wsgi_safe
class SetCurrentUserTest(SimpleHandlerTestCase):
class SetLazyPropertiesTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def prepare(self):
self.current_user = 'Ben'
self.locale = locale.get('en_US')
def get_user_locale(self):
raise NotImplementedError()
def get_current_user(self):
raise NotImplementedError()
def get(self):
self.write('Hello %s' % self.current_user)
self.write('Hello %s (%s)' % (self.current_user, self.locale.code))
def test_set_current_user(self):
def test_set_properties(self):
# Ensure that current_user can be assigned to normally for apps
# that want to forgo the lazy get_current_user property
response = self.fetch('/')
self.assertEqual(response.body, b'Hello Ben')
self.assertEqual(response.body, b'Hello Ben (en_US)')
@wsgi_safe
@ -1850,7 +2049,7 @@ class StreamingRequestBodyTest(WebTestCase):
def connect(self, url, connection_close):
# Use a raw connection so we can control the sending of data.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
s.connect(("127.0.0.1", self.get_http_port()))
stream = IOStream(s, io_loop=self.io_loop)
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
if connection_close:
@ -1902,57 +2101,55 @@ class StreamingRequestBodyTest(WebTestCase):
yield self.close_future
class StreamingRequestFlowControlTest(WebTestCase):
def get_handlers(self):
from tornado.ioloop import IOLoop
# Each method in this handler returns a yieldable object and yields to the
# IOLoop so the future is not immediately ready. Ensure that the
# yieldables are respected and no method is called before the previous
# one has completed.
@stream_request_body
class BaseFlowControlHandler(RequestHandler):
def initialize(self, test):
self.test = test
self.method = None
self.methods = []
# Each method in this handler returns a Future and yields to the
# IOLoop so the future is not immediately ready. Ensure that the
# Futures are respected and no method is called before the previous
# one has completed.
@stream_request_body
class FlowControlHandler(RequestHandler):
def initialize(self, test):
self.test = test
self.method = None
self.methods = []
@contextlib.contextmanager
def in_method(self, method):
if self.method is not None:
self.test.fail("entered method %s while in %s" %
(method, self.method))
self.method = method
self.methods.append(method)
try:
yield
finally:
self.method = None
@contextlib.contextmanager
def in_method(self, method):
if self.method is not None:
self.test.fail("entered method %s while in %s" %
(method, self.method))
self.method = method
self.methods.append(method)
try:
yield
finally:
self.method = None
@gen.coroutine
def prepare(self):
# Note that asynchronous prepare() does not block data_received,
# so we don't use in_method here.
self.methods.append('prepare')
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def prepare(self):
with self.in_method('prepare'):
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def post(self):
with self.in_method('post'):
yield gen.Task(IOLoop.current().add_callback)
self.write(dict(methods=self.methods))
@gen.coroutine
def data_received(self, data):
with self.in_method('data_received'):
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def post(self):
with self.in_method('post'):
yield gen.Task(IOLoop.current().add_callback)
self.write(dict(methods=self.methods))
return [('/', FlowControlHandler, dict(test=self))]
class BaseStreamingRequestFlowControlTest(object):
def get_httpserver_options(self):
# Use a small chunk size so flow control is relevant even though
# all the data arrives at once.
return dict(chunk_size=10)
return dict(chunk_size=10, decompress_request=True)
def test_flow_control(self):
def get_http_client(self):
# simple_httpclient only: curl doesn't support body_producer.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
# Test all the slightly different code paths for fixed, chunked, etc bodies.
def test_flow_control_fixed_body(self):
response = self.fetch('/', body='abcdefghijklmnopqrstuvwxyz',
method='POST')
response.rethrow()
@ -1961,6 +2158,58 @@ class StreamingRequestFlowControlTest(WebTestCase):
'data_received', 'data_received',
'post']))
def test_flow_control_chunked_body(self):
chunks = [b'abcd', b'efgh', b'ijkl']
@gen.coroutine
def body_producer(write):
for i in chunks:
yield write(i)
response = self.fetch('/', body_producer=body_producer, method='POST')
response.rethrow()
self.assertEqual(json_decode(response.body),
dict(methods=['prepare', 'data_received',
'data_received', 'data_received',
'post']))
def test_flow_control_compressed_body(self):
bytesio = BytesIO()
gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
gzip_file.write(b'abcdefghijklmnopqrstuvwxyz')
gzip_file.close()
compressed_body = bytesio.getvalue()
response = self.fetch('/', body=compressed_body, method='POST',
headers={'Content-Encoding': 'gzip'})
response.rethrow()
self.assertEqual(json_decode(response.body),
dict(methods=['prepare', 'data_received',
'data_received', 'data_received',
'post']))
class DecoratedStreamingRequestFlowControlTest(
BaseStreamingRequestFlowControlTest,
WebTestCase):
def get_handlers(self):
class DecoratedFlowControlHandler(BaseFlowControlHandler):
@gen.coroutine
def data_received(self, data):
with self.in_method('data_received'):
yield gen.Task(IOLoop.current().add_callback)
return [('/', DecoratedFlowControlHandler, dict(test=self))]
@skipBefore35
class NativeStreamingRequestFlowControlTest(
BaseStreamingRequestFlowControlTest,
WebTestCase):
def get_handlers(self):
class NativeFlowControlHandler(BaseFlowControlHandler):
data_received = exec_test(globals(), locals(), """
async def data_received(self, data):
with self.in_method('data_received'):
await gen.Task(IOLoop.current().add_callback)
""")["data_received"]
return [('/', NativeFlowControlHandler, dict(test=self))]
@wsgi_safe
class IncorrectContentLengthTest(SimpleHandlerTestCase):
@ -1994,9 +2243,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too high, the connection is simply
# closed without completing the response. An error is logged on
# the server.
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/high")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@ -2006,9 +2256,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too low, the connection is closed
# without writing the last chunk, so the client never sees the request
# complete (which would be a framing error).
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/low")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@ -2018,21 +2269,28 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
class ClientCloseTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
# Simulate a connection closed by the client during
# request processing. The client will see an error, but the
# server should respond gracefully (without logging errors
# because we were unable to write out as many bytes as
# Content-Length said we would)
self.request.connection.stream.close()
self.write('hello')
if self.request.version.startswith('HTTP/1'):
# Simulate a connection closed by the client during
# request processing. The client will see an error, but the
# server should respond gracefully (without logging errors
# because we were unable to write out as many bytes as
# Content-Length said we would)
self.request.connection.stream.close()
self.write('hello')
else:
# TODO: add a HTTP2-compatible version of this test.
self.write('requires HTTP/1.x')
def test_client_close(self):
response = self.fetch('/')
if response.body == b'requires HTTP/1.x':
self.skipTest('requires HTTP/1.x')
self.assertEqual(response.code, 599)
class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"
SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
def past(self):
return self.present() - 86400 * 32
@ -2094,6 +2352,7 @@ class SignedValueTest(unittest.TestCase):
def test_payload_tampering(self):
# These cookies are variants of the one in test_known_values.
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
def validate(prefix):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@ -2108,6 +2367,7 @@ class SignedValueTest(unittest.TestCase):
def test_signature_tampering(self):
prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
def validate(sig):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@ -2137,6 +2397,43 @@ class SignedValueTest(unittest.TestCase):
clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_read_write_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=0)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_read_write_non_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_invalid_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=0)
newkeys = SignedValueTest.SECRET_DICT.copy()
newkeys.pop(0)
decoded = decode_signed_value(newkeys,
"key", signed, clock=self.present)
self.assertEqual(None, decoded)
def test_key_version_retrieval(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
key_version = get_signature_key_version(signed)
self.assertEqual(1, key_version)
@wsgi_safe
class XSRFTest(SimpleHandlerTestCase):
@ -2239,7 +2536,7 @@ class XSRFTest(SimpleHandlerTestCase):
token2 = self.get_token()
# Each token can be used to authenticate its own request.
for token in (self.xsrf_token, token2):
response = self.fetch(
response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf=token)),
headers=self.cookie_headers(token))
@ -2298,18 +2595,171 @@ class XSRFTest(SimpleHandlerTestCase):
self.assertEqual(response.code, 200)
@wsgi_safe
class XSRFCookieKwargsTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.write(self.xsrf_token)
def get_app_kwargs(self):
return dict(xsrf_cookies=True,
xsrf_cookie_kwargs=dict(httponly=True))
def test_xsrf_httponly(self):
response = self.fetch("/")
self.assertIn('httponly;', response.headers['Set-Cookie'].lower())
@wsgi_safe
class FinishExceptionTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.set_status(401)
self.set_header('WWW-Authenticate', 'Basic realm="something"')
self.write('authentication required')
raise Finish()
if self.get_argument('finish_value', ''):
raise Finish('authentication required')
else:
self.write('authentication required')
raise Finish()
def test_finish_exception(self):
response = self.fetch('/')
self.assertEqual(response.code, 401)
self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate'))
self.assertEqual(b'authentication required', response.body)
for url in ['/', '/?finish_value=1']:
response = self.fetch(url)
self.assertEqual(response.code, 401)
self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate'))
self.assertEqual(b'authentication required', response.body)
@wsgi_safe
class DecoratorTest(WebTestCase):
def get_handlers(self):
class RemoveSlashHandler(RequestHandler):
@removeslash
def get(self):
pass
class AddSlashHandler(RequestHandler):
@addslash
def get(self):
pass
return [("/removeslash/", RemoveSlashHandler),
("/addslash", AddSlashHandler),
]
def test_removeslash(self):
response = self.fetch("/removeslash/", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash")
response = self.fetch("/removeslash/?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash?foo=bar")
def test_addslash(self):
response = self.fetch("/addslash", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/")
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")
@wsgi_safe
class CacheTest(WebTestCase):
def get_handlers(self):
class EtagHandler(RequestHandler):
def get(self, computed_etag):
self.write(computed_etag)
def compute_etag(self):
return self._write_buffer[0]
return [
('/etag/(.*)', EtagHandler)
]
def test_wildcard_etag(self):
computed_etag = '"xyzzy"'
etags = '*'
self._test_etag(computed_etag, etags, 304)
def test_strong_etag_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy"'
self._test_etag(computed_etag, etags, 304)
def test_multiple_strong_etag_match(self):
computed_etag = '"xyzzy1"'
etags = '"xyzzy1", "xyzzy2"'
self._test_etag(computed_etag, etags, 304)
def test_strong_etag_not_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy1"'
self._test_etag(computed_etag, etags, 200)
def test_multiple_strong_etag_not_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy1", "xyzzy2"'
self._test_etag(computed_etag, etags, 200)
def test_weak_etag_match(self):
computed_etag = '"xyzzy1"'
etags = 'W/"xyzzy1"'
self._test_etag(computed_etag, etags, 304)
def test_multiple_weak_etag_match(self):
computed_etag = '"xyzzy2"'
etags = 'W/"xyzzy1", W/"xyzzy2"'
self._test_etag(computed_etag, etags, 304)
def test_weak_etag_not_match(self):
computed_etag = '"xyzzy2"'
etags = 'W/"xyzzy1"'
self._test_etag(computed_etag, etags, 200)
def test_multiple_weak_etag_not_match(self):
computed_etag = '"xyzzy3"'
etags = 'W/"xyzzy1", W/"xyzzy2"'
self._test_etag(computed_etag, etags, 200)
def _test_etag(self, computed_etag, etags, status_code):
response = self.fetch(
'/etag/' + computed_etag,
headers={'If-None-Match': etags}
)
self.assertEqual(response.code, status_code)
@wsgi_safe
class RequestSummaryTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
# remote_ip is optional, although it's set by
# both HTTPServer and WSGIAdapter.
# Clobber it to make sure it doesn't break logging.
self.request.remote_ip = None
self.finish(self._request_summary())
def test_missing_remote_ip(self):
resp = self.fetch("/")
self.assertEqual(resp.body, b"GET / (None)")
class HTTPErrorTest(unittest.TestCase):
def test_copy(self):
e = HTTPError(403, reason="Go away")
e2 = copy.copy(e)
self.assertIsNot(e, e2)
self.assertEqual(e.status_code, e2.status_code)
self.assertEqual(e.reason, e2.reason)
class ApplicationTest(AsyncTestCase):
def test_listen(self):
app = Application([])
server = app.listen(0, address='127.0.0.1')
server.stop()

View file

@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import traceback
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
@ -11,7 +12,7 @@ from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket
import tornado.websocket # noqa
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
@ -34,8 +35,12 @@ class TestWebSocketHandler(WebSocketHandler):
This allows for deterministic cleanup of the associated socket.
"""
def initialize(self, close_future):
def initialize(self, close_future, compression_options=None):
self.close_future = close_future
self.compression_options = compression_options
def get_compression_options(self):
return self.compression_options
def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
@ -48,7 +53,7 @@ class EchoHandler(TestWebSocketHandler):
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
1 / 0
class HeaderHandler(TestWebSocketHandler):
@ -70,10 +75,39 @@ class NonWebSocketHandler(RequestHandler):
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.on_close_called = False
self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase):
class AsyncPrepareHandler(TestWebSocketHandler):
@gen.coroutine
def prepare(self):
yield gen.moment
def on_message(self, message):
self.write_message(message)
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
ws = yield websocket_connect(
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
raise gen.Return(ws)
@gen.coroutine
def close(self, ws):
"""Close a websocket connection and wait for the server side.
If we don't wait here, there are sometimes leak warnings in the
tests.
"""
ws.close()
yield self.close_future
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
return Application([
@ -84,6 +118,8 @@ class WebSocketTest(AsyncHTTPTestCase):
dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
('/async_prepare', AsyncPrepareHandler,
dict(close_future=self.close_future)),
])
def test_http_request(self):
@ -93,18 +129,15 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop)
ws.write_message('hello')
ws = yield self.ws_connect('/echo')
yield ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
def test_websocket_callbacks(self):
websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
@ -117,49 +150,39 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test
def test_binary_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws = yield self.ws_connect('/echo')
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_unicode_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws = yield self.ws_connect('/echo')
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_error_in_on_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/error_in_on_message' % self.get_http_port())
ws = yield self.ws_connect('/error_in_on_message')
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(
'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
yield self.ws_connect('/notfound')
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
yield websocket_connect(
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
yield self.ws_connect('/non_ws')
@gen_test
def test_websocket_network_fail(self):
@ -168,16 +191,17 @@ class WebSocketTest(AsyncHTTPTestCase):
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
'ws://127.0.0.1:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
'ws://127.0.0.1:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
# Close the underlying stream.
ws.stream.close()
yield self.close_future
@ -185,68 +209,78 @@ class WebSocketTest(AsyncHTTPTestCase):
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_server_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/close_reason' % self.get_http_port())
ws = yield self.ws_connect('/close_reason')
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
# The on_close callback is called no matter which side closed.
code, reason = yield self.close_future
# The client echoed the close code it received to the server,
# so the server's close code (returned via close_future) is
# the same.
self.assertEqual(code, 1001)
@gen_test
def test_client_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws = yield self.ws_connect('/echo')
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
# result in a timeout on test shutdown (and a memory leak).
ws = yield self.ws_connect('/async_prepare')
ws.write_message('hello')
res = yield ws.read_message()
self.assertEqual(res, 'hello')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d/something' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'localhost:%d' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': '127.0.0.1:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
@ -257,8 +291,8 @@ class WebSocketTest(AsyncHTTPTestCase):
def test_check_origin_invalid(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Host is localhost, which should not be accessible from some other
url = 'ws://127.0.0.1:%d/echo' % port
# Host is 127.0.0.1, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}
@ -284,6 +318,78 @@ class WebSocketTest(AsyncHTTPTestCase):
self.assertEqual(cm.exception.code, 403)
class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])
def get_server_compression_options(self):
return None
def get_client_compression_options(self):
return None
@gen_test
def test_message_sizes(self):
ws = yield self.ws_connect(
'/echo',
compression_options=self.get_client_compression_options())
# Send the same message three times so we can measure the
# effect of the context_takeover options.
for i in range(3):
ws.write_message(self.MESSAGE)
response = yield ws.read_message()
self.assertEqual(response, self.MESSAGE)
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
self.verify_wire_bytes(ws.protocol._wire_bytes_in,
ws.protocol._wire_bytes_out)
yield self.close(ws)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
def verify_wire_bytes(self, bytes_in, bytes_out):
# Bytes out includes the 4-byte mask key per message.
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
pass
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_client_compression_options(self):
return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
def get_client_compression_options(self):
return {}
def verify_wire_bytes(self, bytes_in, bytes_out):
self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
# Bytes out includes the 4 bytes mask key per message.
self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):

View file

@ -19,6 +19,7 @@ try:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
from tornado.process import Subprocess
except ImportError:
# These modules are not importable on app engine. Parts of this module
# won't work, but e.g. LogTrapTestCase and main() will.
@ -28,23 +29,35 @@ except ImportError:
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
from tornado.log import gen_log
Subprocess = None
from tornado.log import gen_log, app_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
import functools
import inspect
import logging
import os
import re
import signal
import socket
import sys
import types
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
try:
from collections.abc import Generator as GeneratorType # py35+
except ImportError:
from types import GeneratorType
if sys.version_info >= (3, 5):
iscoroutine = inspect.iscoroutine
iscoroutinefunction = inspect.iscoroutinefunction
else:
iscoroutine = iscoroutinefunction = lambda f: False
# Tornado's own test suite requires the updated unittest module
# (either py27+ or unittest2) so tornado.test.util enforces
# this requirement, but for other users of tornado.testing we want
@ -79,12 +92,13 @@ def get_unused_port():
return port
def bind_unused_port():
def bind_unused_port(reuse_port=False):
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
"""
[sock] = netutil.bind_sockets(None, 'localhost', family=socket.AF_INET)
[sock] = netutil.bind_sockets(None, 'localhost', family=socket.AF_INET,
reuse_port=reuse_port)
port = sock.getsockname()[1]
return sock, port
@ -114,11 +128,11 @@ class _TestMethodWrapper(object):
def __init__(self, orig_method):
self.orig_method = orig_method
def __call__(self):
result = self.orig_method()
if isinstance(result, types.GeneratorType):
raise TypeError("Generator test methods should be decorated with "
"tornado.testing.gen_test")
def __call__(self, *args, **kwargs):
result = self.orig_method(*args, **kwargs)
if isinstance(result, GeneratorType) or iscoroutine(result):
raise TypeError("Generator and coroutine test methods should be"
" decorated with tornado.testing.gen_test")
elif result is not None:
raise ValueError("Return value from test method ignored: %r" %
result)
@ -214,6 +228,8 @@ class AsyncTestCase(unittest.TestCase):
self.io_loop.make_current()
def tearDown(self):
# Clean up Subprocess, so it can be used again with a new ioloop.
Subprocess.uninitialize()
self.io_loop.clear_current()
if (not IOLoop.initialized() or
self.io_loop is not IOLoop.instance()):
@ -237,7 +253,11 @@ class AsyncTestCase(unittest.TestCase):
return IOLoop()
def _handle_exception(self, typ, value, tb):
self.__failure = (typ, value, tb)
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
app_log.error("multiple unhandled exceptions in test",
exc_info=(typ, value, tb))
self.stop()
return True
@ -323,20 +343,29 @@ class AsyncHTTPTestCase(AsyncTestCase):
Tests will typically use the provided ``self.http_client`` to fetch
URLs from this server.
Example::
Example, assuming the "Hello, world" example from the user guide is in
``hello.py``::
class MyHTTPTest(AsyncHTTPTestCase):
import hello
class TestHelloApp(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', MyHandler)...])
return hello.make_app()
def test_homepage(self):
# The following two lines are equivalent to
# response = self.fetch('/')
# but are shown in full here to demonstrate explicit use
# of self.stop and self.wait.
self.http_client.fetch(self.get_url('/'), self.stop)
response = self.wait()
# test contents of response
response = self.fetch('/')
self.assertEqual(response.code, 200)
self.assertEqual(response.body, 'Hello, world')
That call to ``self.fetch()`` is equivalent to ::
self.http_client.fetch(self.get_url('/'), self.stop)
response = self.wait()
which illustrates how AsyncTestCase can turn an asynchronous operation,
like ``http_client.fetch()``, into a synchronous operation. If you need
to do other asynchronous operations in tests, you'll probably need to use
``stop()`` and ``wait()`` yourself.
"""
def setUp(self):
super(AsyncHTTPTestCase, self).setUp()
@ -395,7 +424,8 @@ class AsyncHTTPTestCase(AsyncTestCase):
def tearDown(self):
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections)
self.io_loop.run_sync(self.http_server.close_all_connections,
timeout=get_async_test_timeout())
if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close()
@ -408,10 +438,8 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
defaults=dict(validate_cert=False))
return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
defaults=dict(validate_cert=False))
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
@ -478,13 +506,16 @@ def gen_test(func=None, timeout=None):
@functools.wraps(f)
def pre_coroutine(self, *args, **kwargs):
result = f(self, *args, **kwargs)
if isinstance(result, types.GeneratorType):
if isinstance(result, GeneratorType) or iscoroutine(result):
self._test_generator = result
else:
self._test_generator = None
return result
coro = gen.coroutine(pre_coroutine)
if iscoroutinefunction(f):
coro = pre_coroutine
else:
coro = gen.coroutine(pre_coroutine)
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
@ -494,8 +525,8 @@ def gen_test(func=None, timeout=None):
timeout=timeout)
except TimeoutError as e:
# run_sync raises an error with an unhelpful traceback.
# If we throw it back into the generator the stack trace
# will be replaced by the point where the test is stopped.
# Throw it back into the generator or coroutine so the stack
# trace is replaced by the point where the test is stopped.
self._test_generator.throw(e)
# In case the test contains an overly broad except clause,
# we may get back here. In this case re-raise the original
@ -534,6 +565,9 @@ class LogTrapTestCase(unittest.TestCase):
`logging.basicConfig` and the "pretty logging" configured by
`tornado.options`. It is not compatible with other log buffering
mechanisms, such as those provided by some test runners.
.. deprecated:: 4.1
Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`.
"""
def run(self, result=None):
logger = logging.getLogger()
@ -565,10 +599,16 @@ class ExpectLog(logging.Filter):
Useful to make tests of error conditions less noisy, while still
leaving unexpected log entries visible. *Not thread safe.*
The attribute ``logged_stack`` is set to true if any exception
stack trace was logged.
Usage::
with ExpectLog('tornado.application', "Uncaught exception"):
error_response = self.fetch("/some_page")
.. versionchanged:: 4.3
Added the ``logged_stack`` attribute.
"""
def __init__(self, logger, regex, required=True):
"""Constructs an ExpectLog context manager.
@ -586,8 +626,11 @@ class ExpectLog(logging.Filter):
self.regex = re.compile(regex)
self.required = required
self.matched = False
self.logged_stack = False
def filter(self, record):
if record.exc_info:
self.logged_stack = True
message = record.getMessage()
if self.regex.match(message):
self.matched = True
@ -596,6 +639,7 @@ class ExpectLog(logging.Filter):
def __enter__(self):
self.logger.addFilter(self)
return self
def __exit__(self, typ, value, tb):
self.logger.removeFilter(self)

View file

@ -13,7 +13,6 @@ and `.Resolver`.
from __future__ import absolute_import, division, print_function, with_statement
import array
import inspect
import os
import sys
import zlib
@ -24,6 +23,13 @@ try:
except NameError:
xrange = range # py3
# inspect.getargspec() raises DeprecationWarnings in Python 3.5.
# The two functions have compatible interfaces for the parts we need.
try:
from inspect import getfullargspec as getargspec # py3
except ImportError:
from inspect import getargspec # py2
class ObjectDict(dict):
"""Makes a dictionary behave like an object, with attribute-style access.
@ -78,6 +84,25 @@ class GzipDecompressor(object):
return self.decompressobj.flush()
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if not isinstance(b'', type('')):
def u(s):
return s
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
# These names don't exist in py3, so use noqa comments to disable
# warnings in flake8.
unicode_type = unicode # noqa
basestring_type = basestring # noqa
def import_object(name):
"""Imports an object by name.
@ -96,6 +121,9 @@ def import_object(name):
...
ImportError: No module named missing_module
"""
if isinstance(name, unicode_type) and str is not unicode_type:
# On python 2 a byte string is required.
name = name.encode('utf-8')
if name.count('.') == 0:
return __import__(name, None, None)
@ -107,24 +135,9 @@ def import_object(name):
raise ImportError("No module named %s" % parts[-1])
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if type('') is not type(b''):
def u(s):
return s
bytes_type = bytes
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
bytes_type = str
unicode_type = unicode
basestring_type = basestring
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
if sys.version_info > (3,):
exec("""
@ -154,7 +167,7 @@ def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
the errno out of the args but if someone instantiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
@ -191,21 +204,21 @@ class Configurable(object):
__impl_class = None
__impl_kwargs = None
def __new__(cls, **kwargs):
def __new__(cls, *args, **kwargs):
base = cls.configurable_base()
args = {}
init_kwargs = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
args.update(base.__impl_kwargs)
init_kwargs.update(base.__impl_kwargs)
else:
impl = cls
args.update(kwargs)
init_kwargs.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatiblity with AsyncHTTPClient
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(**args)
instance.initialize(*args, **init_kwargs)
return instance
@classmethod
@ -226,6 +239,9 @@ class Configurable(object):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
.. versionchanged:: 4.2
Now accepts positional arguments in addition to keyword arguments.
"""
@classmethod
@ -237,7 +253,7 @@ class Configurable(object):
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes_type)):
if isinstance(impl, (unicode_type, bytes)):
impl = import_object(impl)
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)
@ -274,11 +290,26 @@ class ArgReplacer(object):
def __init__(self, func, name):
self.name = name
try:
self.arg_pos = inspect.getargspec(func).args.index(self.name)
self.arg_pos = self._getargnames(func).index(name)
except ValueError:
# Not a positional parameter
self.arg_pos = None
def _getargnames(self, func):
try:
return getargspec(func).args
except TypeError:
if hasattr(func, 'func_code'):
# Cython-generated code has all the attributes needed
# by inspect.getargspec, but the inspect module only
# works with ordinary functions. Inline the portion of
# getargspec that we need here. Note that for static
# functions the @cython.binding(True) decorator must
# be used (for methods it works out of the box).
code = func.func_code
return code.co_varnames[:code.co_argcount]
raise
def get_old_value(self, args, kwargs, default=None):
"""Returns the old value of the named argument without replacing it.
@ -338,7 +369,7 @@ def _websocket_mask_python(mask, data):
return unmasked.tostring()
if (os.environ.get('TORNADO_NO_EXTENSION') or
os.environ.get('TORNADO_EXTENSION') == '0'):
os.environ.get('TORNADO_EXTENSION') == '0'):
# These environment variables exist to make it easier to do performance
# comparisons; they are not guaranteed to remain supported in the future.
_websocket_mask = _websocket_mask_python

File diff suppressed because it is too large Load diff

View file

@ -26,6 +26,7 @@ import os
import struct
import tornado.escape
import tornado.web
import zlib
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
@ -35,12 +36,12 @@ from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado import simple_httpclient
from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, _websocket_mask
from tornado.util import _websocket_mask
try:
from urllib.parse import urlparse # py2
from urllib.parse import urlparse # py2
except ImportError:
from urlparse import urlparse # py3
from urlparse import urlparse # py3
try:
xrange # py2
@ -73,17 +74,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
http://tools.ietf.org/html/rfc6455.
Here is an example WebSocket handler that echos back all received messages
back to the client::
back to the client:
class EchoWebSocket(websocket.WebSocketHandler):
.. testcode::
class EchoWebSocket(tornado.websocket.WebSocketHandler):
def open(self):
print "WebSocket opened"
print("WebSocket opened")
def on_message(self, message):
self.write_message(u"You said: " + message)
def on_close(self):
print "WebSocket closed"
print("WebSocket closed")
.. testoutput::
:hide:
WebSockets are not standard HTTP connections. The "handshake" is
HTTP, but after the handshake, the protocol is
@ -105,6 +111,21 @@ class WebSocketHandler(tornado.web.RequestHandler):
};
This script pops up an alert box that says "You said: Hello, world".
Web browsers allow any site to open a websocket connection to any other,
instead of using the same-origin policy that governs other network
access from javascript. This can be surprising and is a potential
security hole, so since Tornado 4.0 `WebSocketHandler` requires
applications that wish to receive cross-origin websockets to opt in
by overriding the `~WebSocketHandler.check_origin` method (see that
method's docs for details). Failure to do so is the most likely
cause of 403 errors when making a websocket connection.
When using a secure websocket connection (``wss://``) with a self-signed
certificate, the connection from a browser may fail because it wants
to show the "accept this certificate" dialog but has nowhere to show it.
You must first visit a regular HTML page using the same certificate
to accept it before the websocket connection will succeed.
"""
def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
@ -113,6 +134,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.close_code = None
self.close_reason = None
self.stream = None
self._on_close_called = False
@tornado.web.asynchronous
def get(self, *args, **kwargs):
@ -122,16 +144,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.set_status(400)
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
self.finish(log_msg)
gen_log.debug(log_msg)
return
# Connection header should be upgrade. Some proxy servers/load balancers
# Connection header should be upgrade.
# Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
connection = map(lambda s: s.strip().lower(),
headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.set_status(400)
self.finish("\"Connection\" must be \"Upgrade\".")
log_msg = "\"Connection\" must be \"Upgrade\"."
self.finish(log_msg)
gen_log.debug(log_msg)
return
# Handle WebSocket Origin naming convention differences
@ -143,27 +171,28 @@ class WebSocketHandler(tornado.web.RequestHandler):
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
# If there was an origin header, check to make sure it matches
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.set_status(403)
self.finish("Cross origin websockets not allowed")
log_msg = "Cross origin websockets not allowed"
self.finish(log_msg)
gen_log.debug(log_msg)
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection = self.get_websocket_protocol()
if self.ws_connection:
self.ws_connection.accept_connection()
else:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
if not self.stream.closed():
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
@ -178,12 +207,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
.. versionchanged:: 3.2
`WebSocketClosedError` was added (previously a closed connection
would raise an `AttributeError`)
.. versionchanged:: 4.3
Returns a `.Future` which can be used for flow control.
"""
if self.ws_connection is None:
raise WebSocketClosedError()
if isinstance(message, dict):
message = tornado.escape.json_encode(message)
self.ws_connection.write_message(message, binary=binary)
return self.ws_connection.write_message(message, binary=binary)
def select_subprotocol(self, subprotocols):
"""Invoked when a new WebSocket requests specific subprotocols.
@ -198,7 +230,20 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return None
def open(self):
def get_compression_options(self):
"""Override to return compression options for the connection.
If this method returns None (the default), compression will
be disabled. If it returns a dict (even an empty one), it
will be enabled. The contents of the dict may be used to
control the memory and CPU usage of the compression,
but no such options are currently implemented.
.. versionadded:: 4.1
"""
return None
def open(self, *args, **kwargs):
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
@ -275,6 +320,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
browsers, since WebSockets are allowed to bypass the usual same-origin
policies and don't use CORS headers.
To accept all cross-origin traffic (which was the default prior to
Tornado 4.0), simply override this method to always return true::
def check_origin(self, origin):
return True
To allow connections from any subdomain of your site, you might
do something like::
def check_origin(self, origin):
parsed_origin = urllib.parse.urlparse(origin)
return parsed_origin.netloc.endswith(".mydomain.com")
.. versionadded:: 4.0
"""
parsed_origin = urlparse(origin)
@ -306,8 +364,26 @@ class WebSocketHandler(tornado.web.RequestHandler):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
if not self._on_close_called:
self._on_close_called = True
self.on_close()
def send_error(self, *args, **kwargs):
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# If we get an uncaught exception during the handshake,
# we have no choice but to abruptly close the connection.
# TODO: for uncaught exceptions after the handshake,
# we can close the connection more gracefully.
self.stream.close()
def get_websocket_protocol(self):
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
if websocket_version in ("7", "8", "13"):
return WebSocketProtocol13(
self, compression_options=self.get_compression_options())
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
@ -316,7 +392,7 @@ def _wrap_method(method):
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
for method in ["write", "redirect", "set_header", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
@ -355,13 +431,69 @@ class WebSocketProtocol(object):
self.close() # let the subclass cleanup
class _PerMessageDeflateCompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
# There is no symbolic constant for the minimum wbits value.
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._compressor = self._create_compressor()
else:
self._compressor = None
def _create_compressor(self):
return zlib.compressobj(tornado.web.GZipContentEncoding.GZIP_LEVEL,
zlib.DEFLATED, -self._max_wbits)
def compress(self, data):
compressor = self._compressor or self._create_compressor()
data = (compressor.compress(data) +
compressor.flush(zlib.Z_SYNC_FLUSH))
assert data.endswith(b'\x00\x00\xff\xff')
return data[:-4]
class _PerMessageDeflateDecompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._decompressor = self._create_decompressor()
else:
self._decompressor = None
def _create_decompressor(self):
return zlib.decompressobj(-self._max_wbits)
def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
return decompressor.decompress(data + b'\x00\x00\xff\xff')
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
def __init__(self, handler, mask_outgoing=False):
# Bit masks for the first byte of a frame.
FIN = 0x80
RSV1 = 0x40
RSV2 = 0x20
RSV3 = 0x10
RSV_MASK = RSV1 | RSV2 | RSV3
OPCODE_MASK = 0x0f
def __init__(self, handler, mask_outgoing=False,
compression_options=None):
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
@ -372,13 +504,27 @@ class WebSocketProtocol13(WebSocketProtocol):
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
self._compression_options = compression_options
self._decompressor = None
self._compressor = None
self._frame_compressed = None
# The total uncompressed size of all messages received or sent.
# Unicode messages are encoded to utf8.
# Only for testing; subject to change.
self._message_bytes_in = 0
self._message_bytes_out = 0
# The total size of all packets received or sent. Includes
# the effect of compression, frame overhead, and control frames.
self._wire_bytes_in = 0
self._wire_bytes_out = 0
def accept_connection(self):
try:
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
gen_log.debug("Malformed WebSocket request received",
exc_info=True)
self._abort()
return
@ -414,26 +560,102 @@ class WebSocketProtocol13(WebSocketProtocol):
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
% selected)
extension_header = ''
extensions = self._parse_extensions_header(self.request.headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
# TODO: negotiate parameters if compression_options
# specifies limits.
self._create_compressors('server', ext[1])
if ('client_max_window_bits' in ext[1] and
ext[1]['client_max_window_bits'] is None):
# Don't echo an offered client_max_window_bits
# parameter with no value.
del ext[1]['client_max_window_bits']
extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
httputil._encode_header(
'permessage-deflate', ext[1]))
break
if self.stream.closed():
self._abort()
return
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"%s"
"\r\n" % (self._challenge_response(), subprotocol_header)))
"%s%s"
"\r\n" % (self._challenge_response(),
subprotocol_header, extension_header)))
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_frame()
def _write_frame(self, fin, opcode, data):
def _parse_extensions_header(self, headers):
extensions = headers.get("Sec-WebSocket-Extensions", '')
if extensions:
return [httputil._parse_header(e.strip())
for e in extensions.split(',')]
return []
def _process_server_headers(self, key, headers):
"""Process the headers sent by the server to this client connection.
'key' is the websocket handshake challenge/response key.
"""
assert headers['Upgrade'].lower() == 'websocket'
assert headers['Connection'].lower() == 'upgrade'
accept = self.compute_accept_value(key)
assert headers['Sec-Websocket-Accept'] == accept
extensions = self._parse_extensions_header(headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
self._create_compressors('client', ext[1])
else:
raise ValueError("unsupported extension %r", ext)
def _get_compressor_options(self, side, agreed_parameters):
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
"""
options = dict(
persistent=(side + '_no_context_takeover') not in agreed_parameters)
wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
if wbits_header is None:
options['max_wbits'] = zlib.MAX_WBITS
else:
options['max_wbits'] = int(wbits_header)
return options
def _create_compressors(self, side, agreed_parameters):
# TODO: handle invalid parameters gracefully
allowed_keys = set(['server_no_context_takeover',
'client_no_context_takeover',
'server_max_window_bits',
'client_max_window_bits'])
for key in agreed_parameters:
if key not in allowed_keys:
raise ValueError("unsupported compression parameter %r" % key)
other_side = 'client' if (side == 'server') else 'server'
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters))
self._decompressor = _PerMessageDeflateDecompressor(
**self._get_compressor_options(other_side, agreed_parameters))
def _write_frame(self, fin, opcode, data, flags=0):
if fin:
finbit = 0x80
finbit = self.FIN
else:
finbit = 0
frame = struct.pack("B", finbit | opcode)
frame = struct.pack("B", finbit | opcode | flags)
l = len(data)
if self.mask_outgoing:
mask_bit = 0x80
@ -449,7 +671,11 @@ class WebSocketProtocol13(WebSocketProtocol):
mask = os.urandom(4)
data = mask + _websocket_mask(mask, data)
frame += data
self.stream.write(frame)
self._wire_bytes_out += len(frame)
try:
return self.stream.write(frame)
except StreamClosedError:
self._abort()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
@ -458,15 +684,17 @@ class WebSocketProtocol13(WebSocketProtocol):
else:
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
try:
self._write_frame(True, opcode, message)
except StreamClosedError:
self._abort()
assert isinstance(message, bytes)
self._message_bytes_out += len(message)
flags = 0
if self._compressor:
message = self._compressor.compress(message)
flags |= self.RSV1
return self._write_frame(True, opcode, message, flags=flags)
def write_ping(self, data):
"""Send ping frame."""
assert isinstance(data, bytes_type)
assert isinstance(data, bytes)
self._write_frame(True, 0x9, data)
def _receive_frame(self):
@ -476,11 +704,15 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_start(self, data):
self._wire_bytes_in += len(data)
header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80
reserved_bits = header & 0x70
self._frame_opcode = header & 0xf
self._final_frame = header & self.FIN
reserved_bits = header & self.RSV_MASK
self._frame_opcode = header & self.OPCODE_MASK
self._frame_opcode_is_control = self._frame_opcode & 0x8
if self._decompressor is not None and self._frame_opcode != 0:
self._frame_compressed = bool(reserved_bits & self.RSV1)
reserved_bits &= ~self.RSV1
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
@ -497,7 +729,8 @@ class WebSocketProtocol13(WebSocketProtocol):
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
self.stream.read_bytes(self._frame_length,
self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
@ -506,6 +739,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_length_16(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
@ -516,6 +750,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_length_64(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!Q", data)[0]
try:
if self._masked_frame:
@ -526,16 +761,20 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_masking_key(self, data):
self._wire_bytes_in += len(data)
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
self.stream.read_bytes(self._frame_length,
self._on_masked_frame_data)
except StreamClosedError:
self._abort()
def _on_masked_frame_data(self, data):
# Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data):
self._wire_bytes_in += len(data)
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
@ -576,8 +815,12 @@ class WebSocketProtocol13(WebSocketProtocol):
if self.client_terminated:
return
if self._frame_compressed:
data = self._decompressor.decompress(data)
if opcode == 0x1:
# UTF-8 data
self._message_bytes_in += len(data)
try:
decoded = data.decode("utf-8")
except UnicodeDecodeError:
@ -586,6 +829,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
self._message_bytes_in += len(data)
self._run_callback(self.handler.on_message, data)
elif opcode == 0x8:
# Close
@ -594,7 +838,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self.handler.close_code = struct.unpack('>H', data[:2])[0]
if len(data) > 2:
self.handler.close_reason = to_unicode(data[2:])
self.close()
# Echo the received close code, if any (RFC 6455 section 5.5.1).
self.close(self.handler.close_code)
elif opcode == 0x9:
# Ping
self._write_frame(True, 0xA, data)
@ -636,11 +881,16 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
def __init__(self, io_loop, request):
def __init__(self, io_loop, request, on_message_callback=None,
compression_options=None):
self.compression_options = compression_options
self.connect_future = TracebackFuture()
self.protocol = None
self.read_future = None
self.read_queue = collections.deque()
self.key = base64.b64encode(os.urandom(16))
self._on_message_callback = on_message_callback
self.close_code = self.close_reason = None
scheme, sep, rest = request.url.partition(':')
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@ -651,11 +901,19 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
if self.compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
# from the server).
# TODO: set server parameters for deflate extension
# if requested in self.compression_options.
request.headers['Sec-WebSocket-Extensions'] = (
'permessage-deflate; client_max_window_bits')
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.tcp_client, 65536)
104857600, self.tcp_client, 65536, 104857600)
def close(self, code=None, reason=None):
"""Closes the websocket connection.
@ -673,10 +931,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.protocol.close(code, reason)
self.protocol = None
def _on_close(self):
def on_connection_close(self):
if not self.connect_future.done():
self.connect_future.set_exception(StreamClosedError())
self.on_message(None)
self.resolver.close()
super(WebSocketClientConnection, self)._on_close()
self.tcp_client.close()
super(WebSocketClientConnection, self).on_connection_close()
def _on_http_response(self, response):
if not self.connect_future.done():
@ -692,12 +952,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
start_line, headers)
self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
assert self.headers['Sec-Websocket-Accept'] == accept
self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
self.protocol = self.get_websocket_protocol()
self.protocol._process_server_headers(self.key, self.headers)
self.protocol._receive_frame()
if self._timeout is not None:
@ -705,17 +961,25 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self._timeout = None
self.stream = self.connection.detach()
self.stream.set_close_callback(self._on_close)
self.stream.set_close_callback(self.on_connection_close)
# Once we've taken over the connection, clear the final callback
# we set on the http request. This deactivates the error handling
# in simple_httpclient that would otherwise interfere with our
# ability to see exceptions.
self.final_callback = None
self.connect_future.set_result(self)
def write_message(self, message, binary=False):
"""Sends a message to the WebSocket server."""
self.protocol.write_message(message, binary)
return self.protocol.write_message(message, binary)
def read_message(self, callback=None):
"""Reads a message from the WebSocket server.
If on_message_callback was specified at WebSocket
initialization, this function will never return messages
Returns a future whose result is the message, or None
if the connection is closed. If a callback argument
is given it will be called with the future when it is
@ -732,7 +996,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
return future
def on_message(self, message):
if self.read_future is not None:
if self._on_message_callback:
self._on_message_callback(message)
elif self.read_future is not None:
self.read_future.set_result(message)
self.read_future = None
else:
@ -741,15 +1007,41 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
def on_pong(self, data):
pass
def get_websocket_protocol(self):
return WebSocketProtocol13(self, mask_outgoing=True,
compression_options=self.compression_options)
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
on_message_callback=None, compression_options=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
`WebSocketClientConnection`.
``compression_options`` is interpreted in the same way as the
return value of `.WebSocketHandler.get_compression_options`.
The connection supports two styles of operation. In the coroutine
style, the application typically calls
`~.WebSocketClientConnection.read_message` in a loop::
conn = yield websocket_connect(url)
while True:
msg = yield conn.read_message()
if msg is None: break
# Do something with msg
In the callback style, pass an ``on_message_callback`` to
``websocket_connect``. In both styles, a message of ``None``
indicates that the connection has been closed.
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
.. versionchanged:: 4.1
Added ``compression_options`` and ``on_message_callback``.
The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
@ -763,7 +1055,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)
conn = WebSocketClientConnection(io_loop, request,
on_message_callback=on_message_callback,
compression_options=compression_options)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future

View file

@ -32,6 +32,7 @@ provides WSGI support in two ways:
from __future__ import absolute_import, division, print_function, with_statement
import sys
from io import BytesIO
import tornado
from tornado.concurrent import Future
@ -40,12 +41,8 @@ from tornado import httputil
from tornado.log import access_log
from tornado import web
from tornado.escape import native_str
from tornado.util import bytes_type, unicode_type
from tornado.util import unicode_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urllib.parse as urllib_parse # py3
@ -58,7 +55,7 @@ except ImportError:
# here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode_type:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
assert isinstance(s, bytes)
return s.decode('latin1')
def from_wsgi_str(s):
@ -66,7 +63,7 @@ if str is unicode_type:
return s.encode('latin1')
else:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
assert isinstance(s, bytes)
return s
def from_wsgi_str(s):
@ -210,7 +207,7 @@ class WSGIAdapter(object):
body = environ["wsgi.input"].read(
int(headers["Content-Length"]))
else:
body = ""
body = b""
protocol = environ["wsgi.url_scheme"]
remote_ip = environ.get("REMOTE_ADDR", "")
if environ.get("HTTP_HOST"):
@ -256,7 +253,7 @@ class WSGIContainer(object):
container = tornado.wsgi.WSGIContainer(simple_app)
http_server = tornado.httpserver.HTTPServer(container)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
tornado.ioloop.IOLoop.current().start()
This class is intended to let other frameworks (Django, web.py, etc)
run on the Tornado HTTP server and I/O loop.
@ -287,7 +284,8 @@ class WSGIContainer(object):
if not data:
raise Exception("WSGI app did not call start_response")
status_code = int(data["status"].split()[0])
status_code, reason = data["status"].split(' ', 1)
status_code = int(status_code)
headers = data["headers"]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
@ -299,13 +297,12 @@ class WSGIContainer(object):
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))
parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
header_obj = httputil.HTTPHeaders()
for key, value in headers:
parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
parts.append(b"\r\n")
parts.append(body)
request.write(b"".join(parts))
request.finish()
header_obj.add(key, value)
request.connection.write_headers(start_line, header_obj, chunk=body)
request.connection.finish()
self._log(status_code, request)
@staticmethod