Switch to python3

This commit is contained in:
j 2014-09-30 18:15:32 +02:00
commit 9ba4b6a91a
5286 changed files with 677347 additions and 576888 deletions

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""The Tornado web server and tools."""
from __future__ import absolute_import, division, print_function, with_statement
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
# three numbers are the components of the version number. The fourth
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.0"
version_info = (4, 0, 0, 0)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,321 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Automatically restart the server when a source file is modified.
Most applications should not access this module directly. Instead,
pass the keyword argument ``autoreload=True`` to the
`tornado.web.Application` constructor (or ``debug=True``, which
enables this setting and several others). This will enable autoreload
mode as well as checking for changes to templates and static
resources. Note that restarting is a destructive operation and any
requests in progress will be aborted when the process restarts. (If
you want to disable autoreload while using other debug-mode features,
pass both ``debug=True`` and ``autoreload=False``).
This module can also be used as a command-line wrapper around scripts
such as unit test runners. See the `main` method for details.
The command-line wrapper and Application debug modes can be used together.
This combination is encouraged as the wrapper catches syntax errors and
other import-time failures, while debug mode catches changes once
the server has started.
This module depends on `.IOLoop`, so it will not work in WSGI applications
and Google App Engine. It also will not work correctly when `.HTTPServer`'s
multi-process mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
Additionally, modifying these variables will cause reloading to behave
incorrectly.
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
import sys
# sys.path handling
# -----------------
#
# If a module is run with "python -m", the current directory (i.e. "")
# is automatically prepended to sys.path, but not if it is run as
# "path/to/file.py". The processing for "-m" rewrites the former to
# the latter, so subsequent executions won't have the same path as the
# original.
#
# Conversely, when run as path/to/file.py, the directory containing
# file.py gets added to the path, which can cause confusion as imports
# may become relative in spite of the future import.
#
# We address the former problem by setting the $PYTHONPATH environment
# variable before re-execution so the new process will see the correct
# path. We attempt to address the latter problem when tornado.autoreload
# is run as __main__, although we can't fix the general case because
# we cannot reliably reconstruct the original command line
# (http://bugs.python.org/issue14208).
if __name__ == "__main__":
# This sys.path manipulation must come before our imports (as much
# as possible - if we introduced a tornado.sys or tornado.os
# module we'd be in trouble), or else our imports would become
# relative again despite the future import.
#
# There is a separate __main__ block at the end of the file to call main().
if sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
import functools
import logging
import os
import pkgutil
import sys
import traceback
import types
import subprocess
import weakref
from tornado import ioloop
from tornado.log import gen_log
from tornado import process
from tornado.util import exec_in
try:
import signal
except ImportError:
signal = None
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary()
def start(io_loop=None, check_time=500):
"""Begins watching source files for changes using the given `.IOLoop`. """
io_loop = io_loop or ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
scheduler.start()
def wait():
"""Wait for a watched file to change, then restart the process.
Intended to be used at the end of scripts like unit test runners,
to run the tests again after any source file changes (but see also
the command-line interface in `main`)
"""
io_loop = ioloop.IOLoop()
start(io_loop)
io_loop.start()
def watch(filename):
"""Add a file to the watch list.
All imported modules are watched by default.
"""
_watched_files.add(filename)
def add_reload_hook(fn):
"""Add a function to be called before reloading the process.
Note that for open file and socket handles it is generally
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
``tornado.platform.auto.set_close_exec``) instead
of using a reload hook to close them.
"""
_reload_hooks.append(fn)
def _reload_on_update(modify_times):
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
if process.task_id() is not None:
# We're in a child process created by fork_processes. If child
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in sys.modules.values():
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
# module.
if not isinstance(module, types.ModuleType):
continue
path = getattr(module, "__file__", None)
if not path:
continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
_check_file(modify_times, path)
for path in _watched_files:
_check_file(modify_times, path)
def _check_file(modify_times, path):
try:
modified = os.stat(path).st_mtime
except Exception:
return
if path not in modify_times:
modify_times[path] = modified
return
if modify_times[path] != modified:
gen_log.info("%s modified; restarting server", path)
_reload()
def _reload():
global _reload_attempted
_reload_attempted = True
for fn in _reload_hooks:
fn()
if hasattr(signal, "setitimer"):
# Clear the alarm signal set by
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
# sys.path fixes: see comments at top of file. If sys.path[0] is an empty
# string, we were (probably) invoked with -m and the effective path
# is about to change on re-exec. Add the current directory to $PYTHONPATH
# to ensure that the new process sees the same path we did.
path_prefix = '.' + os.pathsep
if (sys.path[0] == '' and
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if sys.platform == 'win32':
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
subprocess.Popen([sys.executable] + sys.argv)
sys.exit(0)
else:
try:
os.execv(sys.executable, [sys.executable] + sys.argv)
except OSError:
# Mac OS X versions prior to 10.6 do not support execv in
# a process that contains multiple threads. Instead of
# re-executing in the current process, start a new one
# and cause the current process to exit. This isn't
# ideal since the new process is detached from the parent
# terminal and thus cannot easily be killed with ctrl-C,
# but it's better than not being able to autoreload at
# all.
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable,
[sys.executable] + sys.argv)
sys.exit(0)
_USAGE = """\
Usage:
python -m tornado.autoreload -m module.to.run [args...]
python -m tornado.autoreload path/to/script.py [args...]
"""
def main():
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
python -m tornado.autoreload -m tornado.test.runtests
python -m tornado.autoreload tornado/test/runtests.py
Running a script with this wrapper is similar to calling
`tornado.autoreload.wait` at the end of the script, but this wrapper
can catch import-time problems like syntax errors that would otherwise
prevent the script from reaching its call to `wait`.
"""
original_argv = sys.argv
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
mode = "module"
module = sys.argv[2]
del sys.argv[1:3]
elif len(sys.argv) >= 2:
mode = "script"
script = sys.argv[1]
sys.argv = sys.argv[1:]
else:
print(_USAGE, file=sys.stderr)
sys.exit(1)
try:
if mode == "module":
import runpy
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
global __file__
__file__ = script
# Use globals as our "locals" dictionary so that
# something that tries to import __main__ (e.g. the unittest
# module) will see the right things.
exec_in(f.read(), globals(), globals())
except SystemExit as e:
logging.basicConfig()
gen_log.info("Script exited with status %s", e.code)
except Exception as e:
logging.basicConfig()
gen_log.warning("Script exited with uncaught exception", exc_info=True)
# If an exception occurred at import time, the file with the error
# never made it into sys.modules and so we won't know to watch it.
# Just to make sure we've covered everything, walk the stack trace
# from the exception and watch every file.
for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]):
watch(filename)
if isinstance(e, SyntaxError):
# SyntaxErrors are special: their innermost stack frame is fake
# so extract_tb won't see it and we have to get the filename
# from the exception object.
watch(e.filename)
else:
logging.basicConfig()
gen_log.info("Script exited normally")
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
if mode == 'module':
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
if loader is not None:
watch(loader.get_filename())
wait()
if __name__ == "__main__":
# See also the other __main__ block at the top of the file, which modifies
# sys.path before our imports
main()

View file

@ -0,0 +1,329 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with threads and ``Futures``.
``Futures`` are a pattern for concurrent programming introduced in
Python 3.2 in the `concurrent.futures` package (this package has also
been backported to older versions of Python and can be installed with
``pip install futures``). Tornado will use `concurrent.futures.Future` if
it is available; otherwise it will use a compatible class defined in this
module.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import sys
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
try:
from concurrent import futures
except ImportError:
futures = None
class ReturnValueIgnoredError(Exception):
pass
class Future(object):
"""Placeholder for an asynchronous result.
A ``Future`` encapsulates the result of an asynchronous
operation. In synchronous applications ``Futures`` are used
to wait for the result from a thread or process pool; in
Tornado they are normally used with `.IOLoop.add_future` or by
yielding them in a `.gen.coroutine`.
`tornado.concurrent.Future` is similar to
`concurrent.futures.Future`, but not thread-safe (and therefore
faster for use with single-threaded event loops).
In addition to ``exception`` and ``set_exception``, methods ``exc_info``
and ``set_exc_info`` are supported to capture tracebacks in Python 2.
The traceback is automatically available in Python 3, but in the
Python 2 futures backport this information is discarded.
This functionality was previously available in a separate class
``TracebackFuture``, which is now a deprecated alias for this class.
.. versionchanged:: 4.0
`tornado.concurrent.Future` is always a thread-unsafe ``Future``
with support for the ``exc_info`` methods. Previously it would
be an alias for the thread-safe `concurrent.futures.Future`
if that package was available and fall back to the thread-unsafe
implementation if it was not.
"""
def __init__(self):
self._done = False
self._result = None
self._exception = None
self._exc_info = None
self._callbacks = []
def cancel(self):
"""Cancel the operation, if possible.
Tornado ``Futures`` do not support cancellation, so this method always
returns False.
"""
return False
def cancelled(self):
"""Returns True if the operation has been cancelled.
Tornado ``Futures`` do not support cancellation, so this method
always returns False.
"""
return False
def running(self):
"""Returns True if this operation is currently running."""
return not self._done
def done(self):
"""Returns True if the future has finished running."""
return self._done
def result(self, timeout=None):
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
"""
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
elif self._exception is not None:
raise self._exception
self._check_done()
return self._result
def exception(self, timeout=None):
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
"""
if self._exception is not None:
return self._exception
else:
self._check_done()
return None
def add_done_callback(self, fn):
"""Attaches the given callback to the `Future`.
It will be invoked with the `Future` as its argument when the Future
has finished running and its result is available. In Tornado
consider using `.IOLoop.add_future` instead of calling
`add_done_callback` directly.
"""
if self._done:
fn(self)
else:
self._callbacks.append(fn)
def set_result(self, result):
"""Sets the result of a ``Future``.
It is undefined to call any of the ``set`` methods more than once
on the same object.
"""
self._result = result
self._set_done()
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
self._exception = exception
self._set_done()
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
return self._exc_info
def set_exc_info(self, exc_info):
"""Sets the exception information of a ``Future.``
Preserves tracebacks on Python 2.
.. versionadded:: 4.0
"""
self._exc_info = exc_info
self.set_exception(exc_info[1])
def _check_done(self):
if not self._done:
raise Exception("DummyFuture does not support blocking for results")
def _set_done(self):
self._done = True
for cb in self._callbacks:
# TODO: error handling
cb(self)
self._callbacks = None
TracebackFuture = Future
if futures is None:
FUTURES = Future
else:
FUTURES = (futures.Future, Future)
def is_future(x):
return isinstance(x, FUTURES)
class DummyExecutor(object):
def submit(self, fn, *args, **kwargs):
future = TracebackFuture()
try:
future.set_result(fn(*args, **kwargs))
except Exception:
future.set_exc_info(sys.exc_info())
return future
def shutdown(self, wait=True):
pass
dummy_executor = DummyExecutor()
def run_on_executor(fn):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
This decorator should be used only on methods of objects with attributes
``executor`` and ``io_loop``.
"""
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
future = self.executor.submit(fn, self, *args, **kwargs)
if callback:
self.io_loop.add_future(future,
lambda future: callback(future.result()))
return future
return wrapper
_NO_RESULT = object()
def return_future(f):
"""Decorator to make a function that returns via callback return a
`Future`.
The wrapped function should take a ``callback`` keyword argument
and invoke it with one argument when it has finished. To signal failure,
the function can simply raise an exception (which will be
captured by the `.StackContext` and passed along to the ``Future``).
From the caller's perspective, the callback argument is optional.
If one is given, it will be invoked when the function is complete
with `Future.result()` as an argument. If the function fails, the
callback will not be run and an exception will be raised into the
surrounding `.StackContext`.
If no callback is given, the caller should use the ``Future`` to
wait for the function to complete (perhaps by yielding it in a
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
Usage::
@return_future
def future_func(arg1, arg2, callback):
# Do stuff (possibly asynchronous)
callback(result)
@gen.engine
def caller(callback):
yield future_func(arg1, arg2)
callback()
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
"""
replacer = ArgReplacer(f, 'callback')
@functools.wraps(f)
def wrapper(*args, **kwargs):
future = TracebackFuture()
callback, args, kwargs = replacer.replace(
lambda value=_NO_RESULT: future.set_result(value),
args, kwargs)
def handle_error(typ, value, tb):
future.set_exc_info((typ, value, tb))
return True
exc_info = None
with ExceptionStackContext(handle_error):
try:
result = f(*args, **kwargs)
if result is not None:
raise ReturnValueIgnoredError(
"@return_future should not be used with functions "
"that return values")
except:
exc_info = sys.exc_info()
raise
if exc_info is not None:
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
raise_exc_info(exc_info)
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens
# just before we return the future, or else we risk confusing
# stack contexts with multiple exceptions (one here with the
# immediate exception, and again when the future resolves and
# the callback triggers its exception by calling future.result()).
if callback is not None:
def run_callback(future):
result = future.result()
if result is _NO_RESULT:
callback()
else:
callback(future.result())
future.add_done_callback(wrap(run_callback))
return future
return wrapper
def chain_future(a, b):
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
``b`` has already been completed or cancelled by the time ``a`` finishes.
"""
def copy(future):
assert future is a
if b.done():
return
if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
and a.exc_info() is not None):
b.set_exc_info(a.exc_info())
elif a.exception() is not None:
b.set_exception(a.exception())
else:
b.set_result(a.result())
a.add_done_callback(copy)

View file

@ -0,0 +1,477 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Non-blocking HTTP client implementation using pycurl."""
from __future__ import absolute_import, division, print_function, with_statement
import collections
import logging
import pycurl
import threading
import time
from tornado import httputil
from tornado import ioloop
from tornado.log import gen_log
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.util import bytes_type
try:
from io import BytesIO # py3
except ImportError:
from cStringIO import StringIO as BytesIO # py2
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, io_loop, max_clients=10, defaults=None):
super(CurlAsyncHTTPClient, self).initialize(io_loop, defaults=defaults)
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [_curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = collections.deque()
self._fds = {}
self._timeout = None
# libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
self._handle_force_timeout, 1000, io_loop=io_loop)
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
# multi object are initialized lazily, and its destructor will
# segfault if it is destroyed without having been used. Add
# and remove a dummy handle to make sure everything is
# initialized.
dummy_curl_handle = pycurl.Curl()
self._multi.add_handle(dummy_curl_handle)
self._multi.remove_handle(dummy_curl_handle)
def close(self):
self._force_timeout_callback.stop()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
for curl in self._curls:
curl.close()
self._multi.close()
super(CurlAsyncHTTPClient, self).close()
def fetch_impl(self, request, callback):
self._requests.append((request, callback))
self._process_queue()
self._set_timeout(0)
def _handle_socket(self, event, fd, multi, data):
"""Called by libcurl when it wants to change the file descriptors
it cares about.
"""
event_map = {
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
self.io_loop.remove_handler(fd)
del self._fds[fd]
else:
ioloop_event = event_map[event]
# libcurl sometimes closes a socket and then opens a new
# one using the same FD without giving us a POLL_NONE in
# between. This is a problem with the epoll IOLoop,
# because the kernel can tell when a socket is closed and
# removes it from the epoll automatically, causing future
# update_handler calls to fail. Since we can't tell when
# this has happened, always use remove and re-add
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
self.io_loop.add_handler(fd, self._handle_events,
ioloop_event)
self._fds[fd] = ioloop_event
def _set_timeout(self, msecs):
"""Called by libcurl to schedule a timeout."""
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
def _handle_events(self, fd, events):
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
action = 0
if events & ioloop.IOLoop.READ:
action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE:
action |= pycurl.CSELECT_OUT
while True:
try:
ret, num_handles = self._multi.socket_action(fd, action)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _handle_timeout(self):
"""Called by IOLoop when the requested timeout has passed."""
with stack_context.NullContext():
self._timeout = None
while True:
try:
ret, num_handles = self._multi.socket_action(
pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
# In theory, we shouldn't have to do this because curl will
# call _set_timeout whenever the timeout changes. However,
# sometimes after _handle_timeout we will need to reschedule
# immediately even though nothing has changed from curl's
# perspective. This is because when socket_action is
# called with SOCKET_TIMEOUT, libcurl decides internally which
# timeouts need to be processed by using a monotonic clock
# (where available) while tornado uses python's time.time()
# to decide when timeouts have occurred. When those clocks
# disagree on elapsed time (as they will whenever there is an
# NTP adjustment), tornado might call _handle_timeout before
# libcurl is ready. After each timeout, resync the scheduled
# timeout with libcurl's current state.
new_timeout = self._multi.timeout()
if new_timeout >= 0:
self._set_timeout(new_timeout)
def _handle_force_timeout(self):
"""Called by IOLoop periodically to ask libcurl to process any
events it may have forgotten about.
"""
with stack_context.NullContext():
while True:
try:
ret, num_handles = self._multi.socket_all()
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _finish_pending_requests(self):
"""Process any requests that were completed by the last
call to multi.socket_action.
"""
while True:
num_q, ok_list, err_list = self._multi.info_read()
for curl in ok_list:
self._finish(curl)
for curl, errnum, errmsg in err_list:
self._finish(curl, errnum, errmsg)
if num_q == 0:
break
self._process_queue()
def _process_queue(self):
with stack_context.NullContext():
while True:
started = 0
while self._free_list and self._requests:
started += 1
curl = self._free_list.pop()
(request, callback) = self._requests.popleft()
curl.info = {
"headers": httputil.HTTPHeaders(),
"buffer": BytesIO(),
"request": request,
"callback": callback,
"curl_start_time": time.time(),
}
_curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"])
self._multi.add_handle(curl)
if not started:
break
def _finish(self, curl, curl_error=None, curl_message=None):
info = curl.info
curl.info = None
self._multi.remove_handle(curl)
self._free_list.append(curl)
buffer = info["buffer"]
if curl_error:
error = CurlError(curl_error, curl_message)
code = error.code
effective_url = None
buffer.close()
buffer = None
else:
error = None
code = curl.getinfo(pycurl.HTTP_CODE)
effective_url = curl.getinfo(pycurl.EFFECTIVE_URL)
buffer.seek(0)
# the various curl timings are documented at
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
time_info = dict(
queue=info["curl_start_time"] - info["request"].start_time,
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
connect=curl.getinfo(pycurl.CONNECT_TIME),
pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME),
starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME),
total=curl.getinfo(pycurl.TOTAL_TIME),
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
info["callback"](HTTPResponse(
request=info["request"], code=code, headers=info["headers"],
buffer=buffer, effective_url=effective_url, error=error,
reason=info['headers'].get("X-Http-Reason", None),
request_time=time.time() - info["curl_start_time"],
time_info=time_info))
except Exception:
self.handle_callback_exception(info["callback"])
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
class CurlError(HTTPError):
def __init__(self, errno, message):
HTTPError.__init__(self, 599, message)
self.errno = errno
def _curl_create():
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, _curl_debug)
return curl
def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.items()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: request.header_callback(native_str(line)))
else:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: _curl_header_callback(headers,
native_str(line)))
if request.streaming_callback:
write_function = request.streaming_callback
else:
write_function = buffer.write
if bytes_type is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
if request.body is None:
raise AssertionError(
'Body must not be empty for "%s" request'
% request.method)
request_buffer = BytesIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.INFILESIZE, len(request.body))
elif request.method == "GET":
if request.body is not None:
raise AssertionError('Body must be empty for GET request')
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(headers, header_line):
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()

View file

@ -0,0 +1,396 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
from __future__ import absolute_import, division, print_function, with_statement
import re
import sys
from tornado.util import bytes_type, unicode_type, basestring_type, u
try:
from urllib.parse import parse_qs as _parse_qs # py3
except ImportError:
from urlparse import parse_qs as _parse_qs # Python 2.6+
try:
import htmlentitydefs # py2
except ImportError:
import html.entities as htmlentitydefs # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse # py2
import json
try:
unichr
except NameError:
unichr = chr
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
'\'': '&#39;'}
def xhtml_escape(value):
"""Escapes a string so it is valid within HTML or XML.
Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
When used in attribute values the escaped strings must be enclosed
in quotes.
.. versionchanged:: 3.2
Added the single quote to the list of escaped characters.
"""
return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
to_basestring(value))
def xhtml_unescape(value):
"""Un-escapes an XML-escaped string."""
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
# The fact that json_encode wraps json.dumps is an implementation detail.
# Please see https://github.com/tornadoweb/tornado/pull/706
# before sending a pull request that adds **kwargs to this function.
def json_encode(value):
"""JSON-encodes the given Python object."""
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javscript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
def json_decode(value):
"""Returns Python objects for the given JSON string."""
return json.loads(to_basestring(value))
def squeeze(value):
"""Replace all sequences of whitespace chars with a single space."""
return re.sub(r"[\x00-\x20]+", " ", value).strip()
def url_escape(value, plus=True):
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
as "+" instead of "%20". This is appropriate for query strings
but not for the path component of a URL. Note that this default
is the reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
quote = urllib_parse.quote_plus if plus else urllib_parse.quote
return quote(utf8(value))
# python 3 changed things around enough that we need two separate
# implementations of url_unescape. We also need our own implementation
# of parse_qs since python 3's version insists on decoding everything.
if sys.version_info[0] < 3:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote)
if encoding is None:
return unquote(utf8(value))
else:
return unicode_type(unquote(utf8(value)), encoding)
parse_qs_bytes = _parse_qs
else:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
if encoding is None:
if plus:
# unquote_to_bytes doesn't have a _plus variant
value = to_basestring(value).replace('+', ' ')
return urllib_parse.unquote_to_bytes(value)
else:
unquote = (urllib_parse.unquote_plus if plus
else urllib_parse.unquote)
return unquote(to_basestring(value), encoding=encoding)
def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
"""Parses a query string like urlparse.parse_qs, but returns the
values as byte strings.
Keys still become type str (interpreted as latin1 in python3!)
because it's too painful to keep them as byte strings in
python3 and in practice they're nearly always ascii anyway.
"""
# This is gross, but python3 doesn't give us another way.
# Latin1 is the universal donor of character encodings.
result = _parse_qs(qs, keep_blank_values, strict_parsing,
encoding='latin1', errors='strict')
encoded = {}
for k, v in result.items():
encoded[k] = [i.encode('latin1') for i in v]
return encoded
_UTF8_TYPES = (bytes_type, type(None))
def utf8(value):
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
Otherwise it must be a unicode string and is encoded as utf8.
"""
if isinstance(value, _UTF8_TYPES):
return value
if not isinstance(value, unicode_type):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
def to_unicode(value):
"""Converts a string argument to a unicode string.
If the argument is already a unicode string or None, it is returned
unchanged. Otherwise it must be a byte string and is decoded as utf8.
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes_type):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.decode("utf-8")
# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode
# When dealing with the standard library across python 2 and 3 it is
# sometimes useful to have a direct conversion to the native string type
if str is unicode_type:
native_str = to_unicode
else:
native_str = utf8
_BASESTRING_TYPES = (basestring_type, type(None))
def to_basestring(value):
"""Converts a string argument to a subclass of basestring.
In python2, byte and unicode strings are mostly interchangeable,
so functions that deal with a user-supplied argument in combination
with ascii string constants can use either and should return the type
the user supplied. In python3, the two types are not interchangeable,
so this method is needed to convert byte strings to unicode.
"""
if isinstance(value, _BASESTRING_TYPES):
return value
if not isinstance(value, bytes_type):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.decode("utf-8")
def recursive_unicode(obj):
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes_type):
return to_unicode(obj)
else:
return obj
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
# dots), causing the regex matcher to never return.
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
_URL_RE = re.compile(to_unicode(r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)"""))
def linkify(text, shorten=False, extra_params="",
require_protocol=False, permitted_protocols=["http", "https"]):
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
Parameters:
* ``shorten``: Long urls will be shortened for display.
* ``extra_params``: Extra text to include in the link tag, or a callable
taking the link as an argument and returning the extra text
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
or::
def extra_params_cb(url):
if url.startswith("http://example.com"):
return 'class="internal"'
else:
return 'class="external" rel="nofollow"'
linkify(text, extra_params=extra_params_cb)
* ``require_protocol``: Only linkify urls which include a protocol. If
this is False, urls such as www.facebook.com will also be linkified.
* ``permitted_protocols``: List (or set) of protocols which should be
linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
"mailto"])``. It is very unsafe to include protocols such as
``javascript``.
"""
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
def make_link(m):
url = m.group(1)
proto = m.group(2)
if require_protocol and not proto:
return url # not protocol, no linkify
if proto and proto not in permitted_protocols:
return url # bad protocol, no linkify
href = m.group(1)
if not proto:
href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
else:
params = extra_params
# clip long urls. max_len is just an approximation
max_len = 30
if shorten and len(url) > max_len:
before_clip = url
if proto:
proto_len = len(proto) + 1 + len(m.group(3) or "") # +1 for :
else:
proto_len = 0
parts = url[proto_len:].split("/")
if len(parts) > 1:
# Grab the whole host part plus the first bit of the path
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
url = url[:proto_len] + parts[0] + "/" + \
parts[1][:8].split('?')[0].split('.')[0]
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
amp = url.rfind('&')
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
url += "..."
if len(url) >= len(before_clip):
url = before_clip
else:
# full url is visible on mouse-over (for those who don't
# have a status bar, such as Safari by default)
params += ' title="%s"' % href
return u('<a href="%s"%s>%s</a>') % (href, params, url)
# First HTML-escape so that our strings are all safe.
# The regex is modified to avoid character entites other than &amp; so
# that we won't pick up &quot;, etc.
text = _unicode(xhtml_escape(text))
return _URL_RE.sub(make_link, text)
def _convert_entity(m):
if m.group(1) == "#":
try:
return unichr(int(m.group(2)))
except ValueError:
return "&#%s;" % m.group(2)
try:
return _HTML_UNICODE_MAP[m.group(2)]
except KeyError:
return "&%s;" % m.group(2)
def _build_unicode_map():
unicode_map = {}
for name, value in htmlentitydefs.name2codepoint.items():
unicode_map[name] = unichr(value)
return unicode_map
_HTML_UNICODE_MAP = _build_unicode_map()

View file

@ -0,0 +1,740 @@
"""``tornado.gen`` is a generator-based interface to make it easier to
work in an asynchronous environment. Code using the ``gen`` module
is technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
For example, the following asynchronous handler::
class AsyncHandler(RequestHandler):
@asynchronous
def get(self):
http_client = AsyncHTTPClient()
http_client.fetch("http://example.com",
callback=self.on_fetch)
def on_fetch(self, response):
do_something_with_response(response)
self.render("template.html")
could be written with ``gen`` as::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response = yield http_client.fetch("http://example.com")
do_something_with_response(response)
self.render("template.html")
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
You can also yield a list or dict of ``Futures``, which will be
started at the same time and run in parallel; a list or dict of results will
be returned when they are all finished::
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response1, response2 = yield [http_client.fetch(url1),
http_client.fetch(url2)]
response_dict = yield dict(response3=http_client.fetch(url3),
response4=http_client.fetch(url4))
response3 = response_dict['response3']
response4 = response_dict['response4']
.. versionchanged:: 3.2
Dict support added.
"""
from __future__ import absolute_import, division, print_function, with_statement
import collections
import functools
import itertools
import sys
import types
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado import stack_context
class KeyReuseError(Exception):
pass
class UnknownKeyError(Exception):
pass
class LeakedCallbackError(Exception):
pass
class BadYieldError(Exception):
pass
class ReturnValueIgnoredError(Exception):
pass
class TimeoutError(Exception):
"""Exception raised by ``with_timeout``."""
def engine(func):
"""Callback-oriented decorator for asynchronous generators.
This is an older interface; for new code that does not need to be
compatible with versions of Tornado older than 3.0 the
`coroutine` decorator is recommended instead.
This decorator is similar to `coroutine`, except it does not
return a `.Future` and the ``callback`` argument is not treated
specially.
In most cases, functions decorated with `engine` should take
a ``callback`` argument and invoke it with their result when
they are finished. One notable exception is the
`~tornado.web.RequestHandler` :ref:`HTTP verb methods <verbs>`,
which use ``self.finish()`` in place of a callback argument.
"""
func = _make_coroutine_wrapper(func, replace_callback=False)
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = func(*args, **kwargs)
def final_callback(future):
if future.result() is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(future.result(),))
future.add_done_callback(final_callback)
return wrapper
def coroutine(func, replace_callback=True):
"""Decorator for asynchronous generators.
Any generator that yields objects from this module must be wrapped
in either this decorator or `engine`.
Coroutines may "return" by raising the special exception
`Return(value) <Return>`. In Python 3.3+, it is also possible for
the function to simply use the ``return value`` statement (prior to
Python 3.3 generators were not allowed to also return values).
In all versions of Python a coroutine that simply wishes to exit
early may use the ``return`` statement without a value.
Functions with this decorator return a `.Future`. Additionally,
they may be called with a ``callback`` keyword argument, which
will be invoked with the future's result when it resolves. If the
coroutine fails, the callback will not be run and an exception
will be raised into the surrounding `.StackContext`. The
``callback`` argument is not visible inside the decorated
function; it is handled by the decorator itself.
From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``.
"""
return _make_coroutine_wrapper(func, replace_callback=True)
def _make_coroutine_wrapper(func, replace_callback):
"""The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
The two decorators differ in their treatment of the ``callback``
argument, so we cannot simply implement ``@engine`` in terms of
``@coroutine``.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = TracebackFuture()
if replace_callback and 'callback' in kwargs:
callback = kwargs.pop('callback')
IOLoop.current().add_future(
future, lambda future: callback(future.result()))
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
except Exception:
future.set_exc_info(sys.exc_info())
return future
else:
if isinstance(result, types.GeneratorType):
# Inline the first iteration of Runner.run. This lets us
# avoid the cost of creating a Runner when the coroutine
# never actually yields, which in turn allows us to
# use "optional" coroutines in critical path code without
# performance penalty for the synchronous case.
try:
orig_stack_contexts = stack_context._state.contexts
yielded = next(result)
if stack_context._state.contexts is not orig_stack_contexts:
yielded = TracebackFuture()
yielded.set_exception(
stack_context.StackContextInconsistentError(
'stack_context inconsistency (probably caused '
'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
future.set_result(getattr(e, 'value', None))
except Exception:
future.set_exc_info(sys.exc_info())
else:
Runner(result, future, yielded)
return future
future.set_result(result)
return future
return wrapper
class Return(Exception):
"""Special exception to return a value from a `coroutine`.
If this exception is raised, its value argument is used as the
result of the coroutine::
@gen.coroutine
def fetch_json(url):
response = yield AsyncHTTPClient().fetch(url)
raise gen.Return(json_decode(response.body))
In Python 3.3, this exception is no longer necessary: the ``return``
statement can be used directly to return a value (previously
``yield`` and ``return`` with a value could not be combined in the
same function).
By analogy with the return statement, the value argument is optional,
but it is never necessary to ``raise gen.Return()``. The ``return``
statement can be used with no arguments instead.
"""
def __init__(self, value=None):
super(Return, self).__init__()
self.value = value
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def start(self, runner):
"""Called by the runner after the generator has yielded.
No other methods will be called on this object before ``start``.
"""
raise NotImplementedError()
def is_ready(self):
"""Called by the runner to determine whether to resume the generator.
Returns a boolean; may be called more than once.
"""
raise NotImplementedError()
def get_result(self):
"""Returns the value to use as the result of the yield expression.
This method will only be called once, and only after `is_ready`
has returned true.
"""
raise NotImplementedError()
class Callback(YieldPoint):
"""Returns a callable object that will allow a matching `Wait` to proceed.
The key may be any value suitable for use as a dictionary key, and is
used to match ``Callbacks`` to their corresponding ``Waits``. The key
must be unique among outstanding callbacks within a single run of the
generator function, but may be reused across different runs of the same
function (so constants generally work fine).
The callback may be called with zero or one arguments; if an argument
is given it will be returned by `Wait`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, key):
self.key = key
def start(self, runner):
self.runner = runner
runner.register_callback(self.key)
def is_ready(self):
return True
def get_result(self):
return self.runner.result_callback(self.key)
class Wait(YieldPoint):
"""Returns the argument passed to the result of a previous `Callback`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, key):
self.key = key
def start(self, runner):
self.runner = runner
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key)
class WaitAll(YieldPoint):
"""Returns the results of multiple previous `Callbacks <Callback>`.
The argument is a sequence of `Callback` keys, and the result is
a list of results in the same order.
`WaitAll` is equivalent to yielding a list of `Wait` objects.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, keys):
self.keys = keys
def start(self, runner):
self.runner = runner
def is_ready(self):
return all(self.runner.is_ready(key) for key in self.keys)
def get_result(self):
return [self.runner.pop_result(key) for key in self.keys]
def Task(func, *args, **kwargs):
"""Adapts a callback-based asynchronous function for use in coroutines.
Takes a function (and optional additional arguments) and runs it with
those arguments plus a ``callback`` keyword argument. The argument passed
to the callback is returned as the result of the yield expression.
.. versionchanged:: 4.0
``gen.Task`` is now a function that returns a `.Future`, instead of
a subclass of `YieldPoint`. It still behaves the same way when
yielded.
"""
future = Future()
def handle_exception(typ, value, tb):
if future.done():
return False
future.set_exc_info((typ, value, tb))
return True
def set_result(result):
if future.done():
return
future.set_result(result)
with stack_context.ExceptionStackContext(handle_exception):
func(*args, callback=_argument_adapter(set_result), **kwargs)
return future
class YieldFuture(YieldPoint):
def __init__(self, future, io_loop=None):
self.future = future
self.io_loop = io_loop or IOLoop.current()
def start(self, runner):
if not self.future.done():
self.runner = runner
self.key = object()
runner.register_callback(self.key)
self.io_loop.add_future(self.future, runner.result_callback(self.key))
else:
self.runner = None
self.result = self.future.result()
def is_ready(self):
if self.runner is not None:
return self.runner.is_ready(self.key)
else:
return True
def get_result(self):
if self.runner is not None:
return self.runner.pop_result(self.key).result()
else:
return self.result
class Multi(YieldPoint):
"""Runs multiple asynchronous operations in parallel.
Takes a list of ``YieldPoints`` or ``Futures`` and returns a list of
their responses. It is not necessary to call `Multi` explicitly,
since the engine will do so automatically when the generator yields
a list of ``YieldPoints`` or a mixture of ``YieldPoints`` and ``Futures``.
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
"""
def __init__(self, children):
self.keys = None
if isinstance(children, dict):
self.keys = list(children.keys())
children = children.values()
self.children = []
for i in children:
if is_future(i):
i = YieldFuture(i)
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
self.unfinished_children = set(self.children)
def start(self, runner):
for i in self.children:
i.start(runner)
def is_ready(self):
finished = list(itertools.takewhile(
lambda i: i.is_ready(), self.unfinished_children))
self.unfinished_children.difference_update(finished)
return not self.unfinished_children
def get_result(self):
result = (i.get_result() for i in self.children)
if self.keys is not None:
return dict(zip(self.keys, result))
else:
return list(result)
def multi_future(children):
"""Wait for multiple asynchronous futures in parallel.
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
a new Future that resolves when all the other Futures are done.
If all the ``Futures`` succeeded, the returned Future's result is a list
of their results. If any failed, the returned Future raises the exception
of the first one to fail.
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
It is not necessary to call `multi_future` explcitly, since the engine will
do so automatically when the generator yields a list of `Futures`.
This function is faster than the `Multi` `YieldPoint` because it does not
require the creation of a stack context.
.. versionadded:: 4.0
"""
if isinstance(children, dict):
keys = list(children.keys())
children = children.values()
else:
keys = None
assert all(is_future(i) for i in children)
unfinished_children = set(children)
future = Future()
if not children:
future.set_result({} if keys is not None else [])
def callback(f):
unfinished_children.remove(f)
if not unfinished_children:
try:
result_list = [i.result() for i in children]
except Exception:
future.set_exc_info(sys.exc_info())
else:
if keys is not None:
future.set_result(dict(zip(keys, result_list)))
else:
future.set_result(result_list)
for f in children:
f.add_done_callback(callback)
return future
def maybe_future(x):
"""Converts ``x`` into a `.Future`.
If ``x`` is already a `.Future`, it is simply returned; otherwise
it is wrapped in a new `.Future`. This is suitable for use as
``result = yield gen.maybe_future(f())`` when you don't know whether
``f()`` returns a `.Future` or not.
"""
if is_future(x):
return x
else:
fut = Future()
fut.set_result(x)
return fut
def with_timeout(timeout, future, io_loop=None):
"""Wraps a `.Future` in a timeout.
Raises `TimeoutError` if the input future does not complete before
``timeout``, which may be specified in any form allowed by
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
Currently only supports Futures, not other `YieldPoint` classes.
.. versionadded:: 4.0
"""
# TODO: allow yield points in addition to futures?
# Tricky to do with stack_context semantics.
#
# It's tempting to optimize this by cancelling the input future on timeout
# instead of creating a new one, but A) we can't know if we are the only
# one waiting on the input future, so cancelling it might disrupt other
# callers and B) concurrent futures can only be cancelled while they are
# in the queue, so cancellation cannot reliably bound our waiting time.
result = Future()
chain_future(future, result)
if io_loop is None:
io_loop = IOLoop.current()
timeout_handle = io_loop.add_timeout(
timeout,
lambda: result.set_exception(TimeoutError("Timeout")))
if isinstance(future, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
# don't care about StackContext here.
future.add_done_callback(
lambda future: io_loop.remove_timeout(timeout_handle))
else:
# concurrent.futures.Futures may resolve on any thread, so we
# need to route them back to the IOLoop.
io_loop.add_future(
future, lambda future: io_loop.remove_timeout(timeout_handle))
return result
_null_future = Future()
_null_future.set_result(None)
moment = Future()
moment.__doc__ = \
"""A special object which may be yielded to allow the IOLoop to run for
one iteration.
This is not needed in normal use but it can be helpful in long-running
coroutines that are likely to yield Futures that are ready instantly.
Usage: ``yield gen.moment``
.. versionadded:: 4.0
"""
moment.set_result(None)
class Runner(object):
"""Internal implementation of `tornado.gen.engine`.
Maintains information about pending callbacks and their results.
The results of the generator are stored in ``result_future`` (a
`.TracebackFuture`)
"""
def __init__(self, gen, result_future, first_yielded):
self.gen = gen
self.result_future = result_future
self.future = _null_future
self.yield_point = None
self.pending_callbacks = None
self.results = None
self.running = False
self.finished = False
self.had_exception = False
self.io_loop = IOLoop.current()
# For efficiency, we do not create a stack context until we
# reach a YieldPoint (stack contexts are required for the historical
# semantics of YieldPoints, but not for Futures). When we have
# done so, this field will be set and must be called at the end
# of the coroutine.
self.stack_context_deactivate = None
if self.handle_yield(first_yielded):
self.run()
def register_callback(self, key):
"""Adds ``key`` to the list of callbacks."""
if self.pending_callbacks is None:
# Lazily initialize the old-style YieldPoint data structures.
self.pending_callbacks = set()
self.results = {}
if key in self.pending_callbacks:
raise KeyReuseError("key %r is already pending" % (key,))
self.pending_callbacks.add(key)
def is_ready(self, key):
"""Returns true if a result is available for ``key``."""
if self.pending_callbacks is None or key not in self.pending_callbacks:
raise UnknownKeyError("key %r is not pending" % (key,))
return key in self.results
def set_result(self, key, result):
"""Sets the result for ``key`` and attempts to resume the generator."""
self.results[key] = result
if self.yield_point is not None and self.yield_point.is_ready():
try:
self.future.set_result(self.yield_point.get_result())
except:
self.future.set_exc_info(sys.exc_info())
self.yield_point = None
self.run()
def pop_result(self, key):
"""Returns the result for ``key`` and unregisters it."""
self.pending_callbacks.remove(key)
return self.results.pop(key)
def run(self):
"""Starts or resumes the generator, running until it reaches a
yield point that is not ready.
"""
if self.running or self.finished:
return
try:
self.running = True
while True:
future = self.future
if not future.done():
return
self.future = None
try:
orig_stack_contexts = stack_context._state.contexts
try:
value = future.result()
except Exception:
self.had_exception = True
yielded = self.gen.throw(*sys.exc_info())
else:
yielded = self.gen.send(value)
if stack_context._state.contexts is not orig_stack_contexts:
self.gen.throw(
stack_context.StackContextInconsistentError(
'stack_context inconsistency (probably caused '
'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
self.finished = True
self.future = _null_future
if self.pending_callbacks and not self.had_exception:
# If we ran cleanly without waiting on all callbacks
# raise an error (really more of a warning). If we
# had an exception then some callbacks may have been
# orphaned, so skip the check in that case.
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.result_future.set_result(getattr(e, 'value', None))
self.result_future = None
self._deactivate_stack_context()
return
except Exception:
self.finished = True
self.future = _null_future
self.result_future.set_exc_info(sys.exc_info())
self.result_future = None
self._deactivate_stack_context()
return
if not self.handle_yield(yielded):
return
finally:
self.running = False
def handle_yield(self, yielded):
if isinstance(yielded, list):
if all(is_future(f) for f in yielded):
yielded = multi_future(yielded)
else:
yielded = Multi(yielded)
elif isinstance(yielded, dict):
if all(is_future(f) for f in yielded.values()):
yielded = multi_future(yielded)
else:
yielded = Multi(yielded)
if isinstance(yielded, YieldPoint):
self.future = TracebackFuture()
def start_yield_point():
try:
yielded.start(self)
if yielded.is_ready():
self.future.set_result(
yielded.get_result())
else:
self.yield_point = yielded
except Exception:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if self.stack_context_deactivate is None:
# Start a stack context if this is the first
# YieldPoint we've seen.
with stack_context.ExceptionStackContext(
self.handle_exception) as deactivate:
self.stack_context_deactivate = deactivate
def cb():
start_yield_point()
self.run()
self.io_loop.add_callback(cb)
return False
else:
start_yield_point()
elif is_future(yielded):
self.future = yielded
if not self.future.done() or self.future is moment:
self.io_loop.add_future(
self.future, lambda f: self.run())
return False
else:
self.future = TracebackFuture()
self.future.set_exception(BadYieldError(
"yielded unknown object %r" % (yielded,)))
return True
def result_callback(self, key):
return stack_context.wrap(_argument_adapter(
functools.partial(self.set_result, key)))
def handle_exception(self, typ, value, tb):
if not self.running and not self.finished:
self.future = TracebackFuture()
self.future.set_exc_info((typ, value, tb))
self.run()
return True
else:
return False
def _deactivate_stack_context(self):
if self.stack_context_deactivate is not None:
self.stack_context_deactivate()
self.stack_context_deactivate = None
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
def _argument_adapter(callback):
"""Returns a function that when invoked runs ``callback`` with one arg.
If the function returned by this function is called with exactly
one argument, that argument is passed to ``callback``. Otherwise
the args tuple and kwargs dict are wrapped in an `Arguments` object.
"""
def wrapper(*args, **kwargs):
if kwargs or len(args) > 1:
callback(Arguments(args, kwargs))
elif args:
callback(args[0])
else:
callback(None)
return wrapper

View file

@ -0,0 +1,651 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Client and server implementations of HTTP/1.x.
.. versionadded:: 4.0
"""
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado import stack_context
from tornado.util import GzipDecompressor
class _QuietException(Exception):
def __init__(self):
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException
"""
def __init__(self, logger):
self.logger = logger
def __enter__(self):
pass
def __exit__(self, typ, value, tb):
if value is not None:
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
def __init__(self, no_keep_alive=False, chunk_size=None,
max_header_size=None, header_timeout=None, max_body_size=None,
body_timeout=None, decompress=False):
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
:arg int chunk_size: how much data to read into memory at once
:arg int max_header_size: maximum amount of data for HTTP headers
:arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds)
:arg bool decompress: if true, decode incoming
``Content-Encoding: gzip``
"""
self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536
self.max_header_size = max_header_size or 65536
self.header_timeout = header_timeout
self.max_body_size = max_body_size
self.body_timeout = body_timeout
self.decompress = decompress
class HTTP1Connection(httputil.HTTPConnection):
"""Implements the HTTP/1.x protocol.
This class can be on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
def __init__(self, stream, is_client, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
:arg params: a `.HTTP1ConnectionParameters` instance or ``None``
:arg context: an opaque application-defined object that can be accessed
as ``connection.context``.
"""
self.is_client = is_client
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
self._max_body_size = (self.params.max_body_size or
self.stream.max_buffer_size)
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
# stream's write buffer.
self._write_finished = False
# True when we have read the entire incoming body.
self._read_finished = False
# _finish_future resolves when all data has been written and flushed
# to the IOStream.
self._finish_future = Future()
# If true, the connection should be closed after this request
# (after the response has been written in the server side,
# and after it has been read in the client)
self._disconnect_on_finish = False
self._clear_callbacks()
# Save the start lines after we read or write them; they
# affect later processing (e.g. 304 responses and HEAD methods
# have content-length but no bodies)
self._request_start_line = None
self._response_start_line = None
self._request_headers = None
# True if we are writing output with chunked encoding.
self._chunking_output = None
# While reading a body with a content-length, this is the
# amount left to read.
self._expected_content_remaining = None
# A Future for our outgoing writes, returned by IOStream.write.
self._pending_write = None
def read_response(self, delegate):
"""Read a single HTTP response.
Typical client-mode usage is to write a request using `write_headers`,
`write`, and `finish`, and then call ``read_response``.
:arg delegate: a `.HTTPMessageDelegate`
Returns a `.Future` that resolves to None after the full response has
been read.
"""
if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
@gen.coroutine
def _read_message(self, delegate):
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
b"\r?\n\r?\n",
max_bytes=self.params.max_header_size)
if self.params.header_timeout is None:
header_data = yield header_future
else:
try:
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
io_loop=self.stream.io_loop)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
start_line, headers = self._parse_headers(header_data)
if self.is_client:
start_line = httputil.parse_response_start_line(start_line)
self._response_start_line = start_line
else:
start_line = httputil.parse_request_start_line(start_line)
self._request_start_line = start_line
self._request_headers = headers
self._disconnect_on_finish = not self._can_keep_alive(
start_line, headers)
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
header_future = delegate.headers_received(start_line, headers)
if header_future is not None:
yield header_future
if self.stream is None:
# We've been detached.
need_delegate_close = False
raise gen.Return(False)
skip_body = False
if self.is_client:
if (self._request_start_line is not None and
self._request_start_line.method == 'HEAD'):
skip_body = True
code = start_line.code
if code == 304:
skip_body = True
if code >= 100 and code < 200:
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
else:
if (headers.get("Expect") == "100-continue" and
not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(headers, delegate)
if body_future is not None:
if self._body_timeout is None:
yield body_future
else:
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future, self.stream.io_loop)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
self.stream.close()
raise gen.Return(False)
self._read_finished = True
if not self._write_finished or self.is_client:
need_delegate_close = False
with _ExceptionLoggingContext(app_log):
delegate.finish()
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
if (not self._finish_future.done() and
self.stream is not None and
not self.stream.closed()):
self.stream.set_close_callback(self._on_connection_close)
yield self._finish_future
if self.is_client and self._disconnect_on_finish:
self.close()
if self.stream is None:
raise gen.Return(False)
except httputil.HTTPInputError as e:
gen_log.info("Malformed HTTP message from %s: %s",
self.context, e)
self.close()
raise gen.Return(False)
finally:
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
self._clear_callbacks()
raise gen.Return(True)
def _clear_callbacks(self):
"""Clears the callback attributes.
This allows the request handler to be garbage collected more
quickly in CPython by breaking up reference cycles.
"""
self._write_callback = None
self._write_future = None
self._close_callback = None
if self.stream is not None:
self.stream.set_close_callback(None)
def set_close_callback(self, callback):
"""Sets a callback that will be run when the connection is closed.
.. deprecated:: 4.0
Use `.HTTPMessageDelegate.on_connection_close` instead.
"""
self._close_callback = stack_context.wrap(callback)
def _on_connection_close(self):
# Note that this callback is only registered on the IOStream
# when we have finished reading the request and are waiting for
# the application to produce its response.
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
if not self._finish_future.done():
self._finish_future.set_result(None)
self._clear_callbacks()
def close(self):
if self.stream is not None:
self.stream.close()
self._clear_callbacks()
if not self._finish_future.done():
self._finish_future.set_result(None)
def detach(self):
"""Take control of the underlying stream.
Returns the underlying `.IOStream` object and stops all further
HTTP processing. May only be called during
`.HTTPMessageDelegate.headers_received`. Intended for implementing
protocols like websockets that tunnel over an HTTP handshake.
"""
self._clear_callbacks()
stream = self.stream
self.stream = None
return stream
def set_body_timeout(self, timeout):
"""Sets the body timeout for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._body_timeout = timeout
def set_max_body_size(self, max_body_size):
"""Sets the body size limit for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._max_body_size = max_body_size
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
if self.is_client:
self._request_start_line = start_line
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
start_line.method in ('POST', 'PUT', 'PATCH') and
'Content-Length' not in headers and
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
self._request_start_line.version == 'HTTP/1.1' and
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding.
# headers.
start_line.code != 304 and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
'Transfer-Encoding' not in headers)
# If a 1.0 client asked for keep-alive, add the header.
if (self._request_start_line.version == 'HTTP/1.0' and
(self._request_headers.get('Connection', '').lower()
== 'keep-alive')):
headers['Connection'] = 'Keep-Alive'
if self._chunking_output:
headers['Transfer-Encoding'] = 'chunked'
if (not self.is_client and
(self._request_start_line.method == 'HEAD' or
start_line.code == 304)):
self._expected_content_remaining = 0
elif 'Content-Length' in headers:
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
lines = [utf8("%s %s %s" % start_line)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
for line in lines:
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
data = b"\r\n".join(lines) + b"\r\n\r\n"
if chunk:
data += self._format_chunk(chunk)
self._pending_write = self.stream.write(data)
self._pending_write.add_done_callback(self._on_write_complete)
return future
def _format_chunk(self, chunk):
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write more data than Content-Length")
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n"
else:
return chunk
def write(self, chunk, callback=None):
"""Implements `.HTTPConnection.write`.
For backwards compatibility is is allowed but deprecated to
skip `write_headers` and instead call `write()` with a
pre-encoded header block.
"""
future = None
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
self._pending_write = self.stream.write(self._format_chunk(chunk))
self._pending_write.add_done_callback(self._on_write_complete)
return future
def finish(self):
"""Implements `.HTTPConnection.finish`."""
if (self._expected_content_remaining is not None and
self._expected_content_remaining != 0 and
not self.stream.closed()):
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length" %
self._expected_content_remaining)
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
self._pending_write.add_done_callback(self._on_write_complete)
self._write_finished = True
# If the app finished the request while we're still reading,
# divert any remaining data away from the delegate and
# close the connection when we're done sending our response.
# Closing the connection is the only way to avoid reading the
# whole input body.
if not self._read_finished:
self._disconnect_on_finish = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if self._pending_write is None:
self._finish_request(None)
else:
self._pending_write.add_done_callback(self._finish_request)
def _on_write_complete(self, future):
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
self.stream.io_loop.add_callback(callback)
if self._write_future is not None:
future = self._write_future
self._write_future = None
future.set_result(None)
def _can_keep_alive(self, start_line, headers):
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers
or start_line.method in ("HEAD", "GET")):
return connection_header == "keep-alive"
return False
def _finish_request(self, future):
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
return
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
if not self._finish_future.done():
self._finish_future.set_result(None)
def _parse_headers(self, data):
data = native_str(data.decode('latin1'))
eol = data.find("\r\n")
start_line = data[:eol]
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
# probably form split() if there was no ':' in the line
raise httputil.HTTPInputError("Malformed HTTP headers: %r" %
data[eol:100])
return start_line, headers
def _read_body(self, headers, delegate):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body(delegate)
if self.is_client:
return self._read_body_until_close(delegate)
return None
@gen.coroutine
def _read_fixed_body(self, content_length, delegate):
while content_length > 0:
body = yield self.stream.read_bytes(
min(self.params.chunk_size, content_length), partial=True)
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(body))
@gen.coroutine
def _read_chunked_body(self, delegate):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
chunk_len = int(chunk_len.strip(), 16)
if chunk_len == 0:
return
total_size += chunk_len
if total_size > self._max_body_size:
raise httputil.HTTPInputError("chunked body too large")
bytes_to_read = chunk_len
while bytes_to_read:
chunk = yield self.stream.read_bytes(
min(bytes_to_read, self.params.chunk_size), partial=True)
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(chunk))
# chunk ends with \r\n
crlf = yield self.stream.read_bytes(2)
assert crlf == b"\r\n"
@gen.coroutine
def _read_body_until_close(self, delegate):
body = yield self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
delegate.data_received(body)
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
"""
def __init__(self, delegate, chunk_size):
self._delegate = delegate
self._chunk_size = chunk_size
self._decompressor = None
def headers_received(self, start_line, headers):
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
headers.add("X-Consumed-Content-Encoding",
headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
@gen.coroutine
def data_received(self, chunk):
if self._decompressor:
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
compressed_data, self._chunk_size)
if decompressed:
yield gen.maybe_future(
self._delegate.data_received(decompressed))
compressed_data = self._decompressor.unconsumed_tail
else:
yield gen.maybe_future(self._delegate.data_received(chunk))
def finish(self):
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
self._delegate.data_received(tail)
return self._delegate.finish()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
def __init__(self, stream, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
:arg context: an opaque application-defined object that is accessible
as ``connection.context``
"""
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self._serving_future = None
@gen.coroutine
def close(self):
"""Closes the connection.
Returns a `.Future` that resolves after the serving loop has exited.
"""
self.stream.close()
# Block until the serving loop is done, but ignore any exceptions
# (start_serving is already responsible for logging them).
try:
yield self._serving_future
except Exception:
pass
def start_serving(self, delegate):
"""Starts serving requests on this connection.
:arg delegate: a `.HTTPServerConnectionDelegate`
"""
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
self._serving_future = self._server_request_loop(delegate)
# Register the future on the IOLoop so its errors get logged.
self.stream.io_loop.add_future(self._serving_future,
lambda f: f.result())
@gen.coroutine
def _server_request_loop(self, delegate):
try:
while True:
conn = HTTP1Connection(self.stream, False,
self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
ret = yield conn.read_response(request_delegate)
except (iostream.StreamClosedError,
iostream.UnsatisfiableReadError):
return
except _QuietException:
# This exception was already logged.
conn.close()
return
except Exception:
gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return
if not ret:
return
yield gen.moment
finally:
delegate.on_close(self)

View file

@ -0,0 +1,638 @@
"""Blocking and non-blocking HTTP client interfaces.
This module defines a common interface shared by two implementations,
``simple_httpclient`` and ``curl_httpclient``. Applications may either
instantiate their chosen implementation class directly or use the
`AsyncHTTPClient` class from this module, which selects an implementation
that can be overridden with the `AsyncHTTPClient.configure` method.
The default implementation is ``simple_httpclient``, and this is expected
to be suitable for most users' needs. However, some applications may wish
to switch to ``curl_httpclient`` for reasons such as the following:
* ``curl_httpclient`` has some features not found in ``simple_httpclient``,
including support for HTTP proxies and the ability to use a specified
network interface.
* ``curl_httpclient`` is more likely to be compatible with sites that are
not-quite-compliant with the HTTP spec, or sites that use little-exercised
features of HTTP.
* ``curl_httpclient`` is faster.
* ``curl_httpclient`` was the default prior to Tornado 2.0.
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
7.21.1, and the minimum version of pycurl is 7.18.2. It is highly
recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more
information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py).
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import time
import weakref
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str
from tornado import httputil, stack_context
from tornado.ioloop import IOLoop
from tornado.util import Configurable
class HTTPClient(object):
"""A blocking HTTP client.
This interface is provided for convenience and testing; most applications
that are running an IOLoop will want to use `AsyncHTTPClient` instead.
Typical usage looks like this::
http_client = httpclient.HTTPClient()
try:
response = http_client.fetch("http://www.google.com/")
print response.body
except httpclient.HTTPError as e:
print "Error:", e
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop()
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop, **kwargs)
self._closed = False
def __del__(self):
self.close()
def close(self):
"""Closes the HTTPClient, freeing any resources used."""
if not self._closed:
self._async_client.close()
self._io_loop.close()
self._closed = True
def fetch(self, request, **kwargs):
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError`.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
response.rethrow()
return response
class AsyncHTTPClient(Configurable):
"""An non-blocking HTTP client.
Example usage::
def handle_request(response):
if response.error:
print "Error:", response.error
else:
print response.body
http_client = AsyncHTTPClient()
http_client.fetch("http://www.google.com/", handle_request)
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
(one per `.IOLoop`). The keyword argument ``force_instance=True``
can be used to suppress this singleton behavior. Unless
``force_instance=True`` is used, no arguments other than
``io_loop`` should be passed to the `AsyncHTTPClient` constructor.
The implementation subclass as well as arguments to its
constructor can be set with the static method `configure()`
All `AsyncHTTPClient` implementations support a ``defaults``
keyword argument, which can be used to set default values for
`HTTPRequest` attributes. For example::
AsyncHTTPClient.configure(
None, defaults=dict(user_agent="MyUserAgent"))
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
"""
@classmethod
def configurable_base(cls):
return AsyncHTTPClient
@classmethod
def configurable_default(cls):
from tornado.simple_httpclient import SimpleAsyncHTTPClient
return SimpleAsyncHTTPClient
@classmethod
def _async_clients(cls):
attr_name = '_async_client_dict_' + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
def __new__(cls, io_loop=None, force_instance=False, **kwargs):
io_loop = io_loop or IOLoop.current()
if force_instance:
instance_cache = None
else:
instance_cache = cls._async_clients()
if instance_cache is not None and io_loop in instance_cache:
return instance_cache[io_loop]
instance = super(AsyncHTTPClient, cls).__new__(cls, io_loop=io_loop,
**kwargs)
# Make sure the instance knows which cache to remove itself from.
# It can't simply call _async_clients() because we may be in
# __new__(AsyncHTTPClient) but instance.__class__ may be
# SimpleAsyncHTTPClient.
instance._instance_cache = instance_cache
if instance_cache is not None:
instance_cache[instance.io_loop] = instance
return instance
def initialize(self, io_loop, defaults=None):
self.io_loop = io_loop
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
self._closed = False
def close(self):
"""Destroys this HTTP client, freeing any file descriptors used.
This method is **not needed in normal use** due to the way
that `AsyncHTTPClient` objects are transparently reused.
``close()`` is generally only necessary when either the
`.IOLoop` is also being closed, or the ``force_instance=True``
argument was used when creating the `AsyncHTTPClient`.
No other methods may be called on the `AsyncHTTPClient` after
``close()``.
"""
if self._closed:
return
self._closed = True
if self._instance_cache is not None:
if self._instance_cache.get(self.io_loop) is not self:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
def fetch(self, request, callback=None, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
the request returned a non-200 response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
"""
if self._closed:
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = httputil.HTTPHeaders(request.headers)
request = _RequestProxy(request, self.defaults)
future = TracebackFuture()
if callback is not None:
callback = stack_context.wrap(callback)
def handle_future(future):
exc = future.exception()
if isinstance(exc, HTTPError) and exc.response is not None:
response = exc.response
elif exc is not None:
response = HTTPResponse(
request, 599, error=exc,
request_time=time.time() - request.start_time)
else:
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
def handle_response(response):
if response.error:
future.set_exception(response.error)
else:
future.set_result(response)
self.fetch_impl(request, handle_response)
return future
def fetch_impl(self, request, callback):
raise NotImplementedError()
@classmethod
def configure(cls, impl, **kwargs):
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
This method may be called with either a class object or the
fully-qualified name of such a class (or ``None`` to use the default,
``SimpleAsyncHTTPClient``)
If additional keyword arguments are given, they will be passed
to the constructor of each subclass instance created. The
keyword argument ``max_clients`` determines the maximum number
of simultaneous `~AsyncHTTPClient.fetch()` operations that can
execute in parallel on each `.IOLoop`. Additional arguments
may be supported depending on the implementation class in use.
Example::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
super(AsyncHTTPClient, cls).configure(impl, **kwargs)
class HTTPRequest(object):
"""HTTP client request object."""
# Default values for HTTPRequest parameters.
# Merged with the values on the request object by AsyncHTTPClient
# implementations.
_DEFAULTS = dict(
connect_timeout=20.0,
request_timeout=20.0,
follow_redirects=True,
max_redirects=5,
decompress_response=True,
proxy_password='',
allow_nonstandard_methods=False,
validate_cert=True)
def __init__(self, url, method="GET", headers=None, body=None,
auth_username=None, auth_password=None, auth_mode=None,
connect_timeout=None, request_timeout=None,
if_modified_since=None, follow_redirects=None,
max_redirects=None, user_agent=None, use_gzip=None,
network_interface=None, streaming_callback=None,
header_callback=None, prepare_curl_callback=None,
proxy_host=None, proxy_port=None, proxy_username=None,
proxy_password=None, allow_nonstandard_methods=None,
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False, decompress_response=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
:arg string method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
:arg body_producer: Callable used for lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should
return a `.Future`. It should call the write function with new
data as it becomes available. The write function returns a
`.Future` which can be used for flow control.
Only one of ``body`` and ``body_producer`` may
be specified. ``body_producer`` is not supported on
``curl_httpclient``. When using ``body_producer`` it is recommended
to pass a ``Content-Length`` in the headers as otherwise chunked
encoding will be used, and many servers do not support chunked
encoding on requests. New in Tornado 4.0
:arg string auth_username: Username for HTTP authentication
:arg string auth_password: Password for HTTP authentication
:arg string auth_mode: Authentication mode; default is "basic".
Allowed values are implementation-defined; ``curl_httpclient``
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds
:arg float request_timeout: Timeout for entire request in seconds
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
or return the 3xx response?
:arg int max_redirects: Limit for ``follow_redirects``
:arg string user_agent: String to send as ``User-Agent`` header
:arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
:arg string network_interface: Network interface to use for request.
``curl_httpclient`` only; see note below.
:arg callable streaming_callback: If set, ``streaming_callback`` will
be run with each chunk of data as it is received, and
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
the final response.
:arg callable header_callback: If set, ``header_callback`` will
be run with each header line as it is received (including the
first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
containing only ``\r\n``. All lines include the trailing newline
characters). ``HTTPResponse.headers`` will be empty in the final
response. This is most useful in conjunction with
``streaming_callback``, because it's the only way to get access to
header data while the request is in progress.
:arg callable prepare_curl_callback: If set, will be called with
a ``pycurl.Curl`` object to allow the application to make additional
``setopt`` calls.
:arg string proxy_host: HTTP proxy hostname. To use proxies,
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
``proxy_pass`` are optional. Proxies are currently only supported
with ``curl_httpclient``.
:arg int proxy_port: HTTP proxy port
:arg string proxy_username: HTTP proxy username
:arg string proxy_password: HTTP proxy password
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
argument?
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate?
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
``simple_httpclient`` and true in ``curl_httpclient``
:arg string client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg string client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
simple_httpclient.
.. note::
When using ``curl_httpclient`` certain options may be
inherited by subsequent fetches because ``pycurl`` does
not allow them to be cleanly reset. This applies to the
``ca_certs``, ``client_key``, ``client_cert``, and
``network_interface`` arguments. If you use these
options, you should pass them on every request (you don't
have to always use the same values, but it's not possible
to mix requests that specify these options with ones that
use the defaults).
.. versionadded:: 3.1
The ``auth_mode`` argument.
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
"""
# Note that some of these attributes go through property setters
# defined below.
self.headers = headers
if if_modified_since:
self.headers["If-Modified-Since"] = httputil.format_timestamp(
if_modified_since)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.url = url
self.method = method
self.body = body
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
self.connect_timeout = connect_timeout
self.request_timeout = request_timeout
self.follow_redirects = follow_redirects
self.max_redirects = max_redirects
self.user_agent = user_agent
if decompress_response is not None:
self.decompress_response = decompress_response
else:
self.decompress_response = use_gzip
self.network_interface = network_interface
self.streaming_callback = streaming_callback
self.header_callback = header_callback
self.prepare_curl_callback = prepare_curl_callback
self.allow_nonstandard_methods = allow_nonstandard_methods
self.validate_cert = validate_cert
self.ca_certs = ca_certs
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
@property
def headers(self):
return self._headers
@headers.setter
def headers(self, value):
if value is None:
self._headers = httputil.HTTPHeaders()
else:
self._headers = value
@property
def body(self):
return self._body
@body.setter
def body(self, value):
self._body = utf8(value)
@property
def body_producer(self):
return self._body_producer
@body_producer.setter
def body_producer(self, value):
self._body_producer = stack_context.wrap(value)
@property
def streaming_callback(self):
return self._streaming_callback
@streaming_callback.setter
def streaming_callback(self, value):
self._streaming_callback = stack_context.wrap(value)
@property
def header_callback(self):
return self._header_callback
@header_callback.setter
def header_callback(self, value):
self._header_callback = stack_context.wrap(value)
@property
def prepare_curl_callback(self):
return self._prepare_curl_callback
@prepare_curl_callback.setter
def prepare_curl_callback(self, value):
self._prepare_curl_callback = stack_context.wrap(value)
class HTTPResponse(object):
"""HTTP Response object.
Attributes:
* request: HTTPRequest object
* code: numeric HTTP status code, e.g. 200 or 404
* reason: human-readable reason phrase describing the status code
* headers: `tornado.httputil.HTTPHeaders` object
* effective_url: final location of the resource after following any
redirects
* buffer: ``cStringIO`` object for response body
* body: response body as string (created on demand from ``self.buffer``)
* error: Exception object, if any
* request_time: seconds from request start to finish
* time_info: dictionary of diagnostic timing information from the request.
Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
"""
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, request_time=None,
time_info=None, reason=None):
if isinstance(request, _RequestProxy):
self.request = request.request
else:
self.request = request
self.code = code
self.reason = reason or httputil.responses.get(code, "Unknown")
if headers is not None:
self.headers = headers
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
self._body = None
if effective_url is None:
self.effective_url = request.url
else:
self.effective_url = effective_url
if error is None:
if self.code < 200 or self.code >= 300:
self.error = HTTPError(self.code, message=self.reason,
response=self)
else:
self.error = None
else:
self.error = error
self.request_time = request_time
self.time_info = time_info or {}
def _get_body(self):
if self.buffer is None:
return None
elif self._body is None:
self._body = self.buffer.getvalue()
return self._body
body = property(_get_body)
def rethrow(self):
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
raise self.error
def __repr__(self):
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
return "%s(%s)" % (self.__class__.__name__, args)
class HTTPError(Exception):
"""Exception thrown for an unsuccessful HTTP request.
Attributes:
* ``code`` - HTTP error integer error code, e.g. 404. Error code 599 is
used when no HTTP response was received, e.g. for a timeout.
* ``response`` - `HTTPResponse` object, if any.
Note that if ``follow_redirects`` is False, redirects become HTTPErrors,
and you can look at ``error.response.headers['Location']`` to see the
destination of the redirect.
"""
def __init__(self, code, message=None, response=None):
self.code = code
message = message or httputil.responses.get(code, "Unknown")
self.response = response
Exception.__init__(self, "HTTP %d: %s" % (self.code, message))
class _RequestProxy(object):
"""Combines an object with a dictionary of defaults.
Used internally by AsyncHTTPClient implementations.
"""
def __init__(self, request, defaults):
self.request = request
self.defaults = defaults
def __getattr__(self, name):
request_attr = getattr(self.request, name)
if request_attr is not None:
return request_attr
elif self.defaults is not None:
return self.defaults.get(name, None)
else:
return None
def main():
from tornado.options import define, options, parse_command_line
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
define("validate_cert", type=bool, default=True)
args = parse_command_line()
client = HTTPClient()
for arg in args:
try:
response = client.fetch(arg,
follow_redirects=options.follow_redirects,
validate_cert=options.validate_cert,
)
except HTTPError as e:
if e.response is not None:
response = e.response
else:
raise
if options.print_headers:
print(response.headers)
if options.print_body:
print(native_str(response.body))
client.close()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,297 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded HTTP server.
Typical applications have little direct interaction with the `HTTPServer`
class except to start a server at the beginning of the process
(and even that is often done indirectly via `tornado.web.Application.listen`).
.. versionchanged:: 4.0
The ``HTTPRequest`` class that used to live in this module has been moved
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
from __future__ import absolute_import, division, print_function, with_statement
import socket
from tornado.escape import native_str
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by either a request callback that takes a
`.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
instance.
A simple example server that echoes back the URI you requested::
import tornado.httpserver
import tornado.ioloop
def handle_request(request):
message = "You requested %s\n" % request.uri
request.connection.write_headers(
httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
{"Content-Length": str(len(message))})
request.connection.write(message)
request.connection.finish()
http_server = tornado.httpserver.HTTPServer(handle_request)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
Applications should use the methods of `.HTTPConnection` to write
their response.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
requests ``Connection: keep-alive``).
If ``xheaders`` is ``True``, we support the
``X-Real-Ip``/``X-Forwarded-For`` and
``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
remote IP and URI scheme/protocol for all requests. These headers
are useful when running Tornado behind a reverse proxy or load
balancer. The ``protocol`` argument can also be set to ``https``
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
To make this server serve SSL traffic, send the ``ssl_options`` dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
an `ssl.SSLContext` object instead of a dict)::
HTTPServer(applicaton, ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
1. `~tornado.tcpserver.TCPServer.listen`: simple single-process::
server = HTTPServer(app)
server.listen(8888)
IOLoop.instance().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
2. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`:
simple multi-process::
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
the server on the default singleton `.IOLoop`.
3. `~tornado.tcpserver.TCPServer.add_sockets`: advanced multi-process::
sockets = tornado.netutil.bind_sockets(8888)
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.instance().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
give you more flexibility in when the fork happens.
`~.TCPServer.add_sockets` can also be used in single-process
servers if you want to create your listening sockets in some
way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0
Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters(
decompress=decompress_request,
chunk_size=chunk_size,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout)
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
max_buffer_size=max_buffer_size,
read_chunk_size=chunk_size)
self._connections = set()
@gen.coroutine
def close_all_connections(self):
while self._connections:
# Peek at an arbitrary element of the set
conn = next(iter(self._connections))
yield conn.close()
def handle_stream(self, stream, address):
context = _HTTPRequestContext(stream, address,
self.protocol)
conn = HTTP1ServerConnection(
stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
return _ServerRequestAdapter(self, request_conn)
def on_close(self, server_conn):
self._connections.remove(server_conn)
class _HTTPRequestContext(object):
def __init__(self, stream, address, protocol):
self.address = address
self.protocol = protocol
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
if stream.socket is not None:
self.address_family = stream.socket.family
else:
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
address is not None):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
self.remote_ip = '0.0.0.0'
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
self._orig_remote_ip = self.remote_ip
self._orig_protocol = self.protocol
def __str__(self):
if self.address_family in (socket.AF_INET, socket.AF_INET6):
return self.remote_ip
elif isinstance(self.address, bytes):
# Python 3 with the -bb option warns about str(bytes),
# so convert it explicitly.
# Unix socket addresses are str on mac but bytes on linux.
return native_str(self.address)
else:
return str(self.address)
def _apply_xheaders(self, headers):
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
ip = ip.split(',')[-1].strip()
ip = headers.get("X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
"X-Scheme", headers.get("X-Forwarded-Proto",
self.protocol))
if proto_header in ("http", "https"):
self.protocol = proto_header
def _unapply_xheaders(self):
"""Undo changes from `_apply_xheaders`.
Xheaders are per-request so they should not leak to the next
request on the same connection.
"""
self.remote_ip = self._orig_remote_ip
self.protocol = self._orig_protocol
class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
def __init__(self, server, connection):
self.server = server
self.connection = connection
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
self.delegate = server.request_callback.start_request(connection)
self._chunks = None
else:
self.delegate = None
self._chunks = []
def headers_received(self, start_line, headers):
if self.server.xheaders:
self.connection.context._apply_xheaders(headers)
if self.delegate is None:
self.request = httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line,
headers=headers)
else:
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
if self.delegate is None:
self._chunks.append(chunk)
else:
return self.delegate.data_received(chunk)
def finish(self):
if self.delegate is None:
self.request.body = b''.join(self._chunks)
self.request._parse_body()
self.server.request_callback(self.request)
else:
self.delegate.finish()
self._cleanup()
def on_connection_close(self):
if self.delegate is None:
self._chunks = None
else:
self.delegate.on_connection_close()
self._cleanup()
def _cleanup(self):
if self.server.xheaders:
self.connection.context._unapply_xheaders()
HTTPRequest = httputil.HTTPServerRequest

View file

@ -0,0 +1,844 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""HTTP utility code shared by clients and servers.
This module also defines the `HTTPServerRequest` class which is exposed
via `tornado.web.RequestHandler.request`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import calendar
import collections
import copy
import datetime
import email.utils
import numbers
import re
import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict, bytes_type
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
try:
from httplib import responses # py2
except ImportError:
from http.client import responses # py3
# responses is unused in this file, but we re-export it to other files.
# Reference it so pyflakes doesn't complain.
responses
try:
from urllib import urlencode # py2
except ImportError:
from urllib.parse import urlencode # py3
try:
from ssl import SSLError
except ImportError:
# ssl is unavailable on app engine.
class SSLError(Exception):
pass
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
Implemented as a dict subclass so that cache hits are as fast as a
normal dict lookup, without the overhead of a python function
call.
>>> normalized_headers = _NormalizedHeaderCache(10)
>>> normalized_headers["coNtent-TYPE"]
'Content-Type'
"""
def __init__(self, size):
super(_NormalizedHeaderCache, self).__init__()
self.size = size
self.queue = collections.deque()
def __missing__(self, key):
normalized = "-".join([w.capitalize() for w in key.split("-")])
self[key] = normalized
self.queue.append(key)
if len(self.queue) > self.size:
# Limit the size of the cache. LRU would be better, but this
# simpler approach should be fine. In Python 2.7+ we could
# use OrderedDict (or in 3.2+, @functools.lru_cache).
old_key = self.queue.popleft()
del self[old_key]
return normalized
_normalized_headers = _NormalizedHeaderCache(1000)
class HTTPHeaders(dict):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.
Supports multiple values per key via a pair of new methods,
`add()` and `get_list()`. The regular dictionary interface
returns a single value per key, with multiple values joined by a
comma.
>>> h = HTTPHeaders({"content-type": "text/html"})
>>> list(h.keys())
['Content-Type']
>>> h["Content-Type"]
'text/html'
>>> h.add("Set-Cookie", "A=B")
>>> h.add("Set-Cookie", "C=D")
>>> h["set-cookie"]
'A=B,C=D'
>>> h.get_list("set-cookie")
['A=B', 'C=D']
>>> for (k,v) in sorted(h.get_all()):
... print('%s: %s' % (k,v))
...
Content-Type: text/html
Set-Cookie: A=B
Set-Cookie: C=D
"""
def __init__(self, *args, **kwargs):
# Don't pass args or kwargs to dict.__init__, as it will bypass
# our __setitem__
dict.__init__(self)
self._as_list = {}
self._last_key = None
if (len(args) == 1 and len(kwargs) == 0 and
isinstance(args[0], HTTPHeaders)):
# Copy constructor
for k, v in args[0].get_all():
self.add(k, v)
else:
# Dict-style initialization
self.update(*args, **kwargs)
# new public methods
def add(self, name, value):
"""Adds a new value for the given key."""
norm_name = _normalized_headers[name]
self._last_key = norm_name
if norm_name in self:
# bypass our override of __setitem__ since it modifies _as_list
dict.__setitem__(self, norm_name,
native_str(self[norm_name]) + ',' +
native_str(value))
self._as_list[norm_name].append(value)
else:
self[norm_name] = value
def get_list(self, name):
"""Returns all values for the given header as a list."""
norm_name = _normalized_headers[name]
return self._as_list.get(norm_name, [])
def get_all(self):
"""Returns an iterable of all (name, value) pairs.
If a header has multiple values, multiple pairs will be
returned with the same name.
"""
for name, values in self._as_list.items():
for value in values:
yield (name, value)
def parse_line(self, line):
"""Updates the dictionary with a single header line.
>>> h = HTTPHeaders()
>>> h.parse_line("Content-Type: text/html")
>>> h.get('content-type')
'text/html'
"""
if line[0].isspace():
# continuation of a multi-line header
new_part = ' ' + line.lstrip()
self._as_list[self._last_key][-1] += new_part
dict.__setitem__(self, self._last_key,
self[self._last_key] + new_part)
else:
name, value = line.split(":", 1)
self.add(name, value.strip())
@classmethod
def parse(cls, headers):
"""Returns a dictionary from HTTP header text.
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
>>> sorted(h.items())
[('Content-Length', '42'), ('Content-Type', 'text/html')]
"""
h = cls()
for line in headers.splitlines():
if line:
h.parse_line(line)
return h
# dict implementation overrides
def __setitem__(self, name, value):
norm_name = _normalized_headers[name]
dict.__setitem__(self, norm_name, value)
self._as_list[norm_name] = [value]
def __getitem__(self, name):
return dict.__getitem__(self, _normalized_headers[name])
def __delitem__(self, name):
norm_name = _normalized_headers[name]
dict.__delitem__(self, norm_name)
del self._as_list[norm_name]
def __contains__(self, name):
norm_name = _normalized_headers[name]
return dict.__contains__(self, norm_name)
def get(self, name, default=None):
return dict.get(self, _normalized_headers[name], default)
def update(self, *args, **kwargs):
# dict.update bypasses our __setitem__
for k, v in dict(*args, **kwargs).items():
self[k] = v
def copy(self):
# default implementation returns dict(self), not the subclass
return HTTPHeaders(self)
class HTTPServerRequest(object):
"""A single HTTP request.
All attributes are type `str` unless otherwise noted.
.. attribute:: method
HTTP request method, e.g. "GET" or "POST"
.. attribute:: uri
The requested uri.
.. attribute:: path
The path portion of `uri`
.. attribute:: query
The query portion of `uri`
.. attribute:: version
HTTP version specified in request, e.g. "HTTP/1.1"
.. attribute:: headers
`.HTTPHeaders` dictionary-like object for request headers. Acts like
a case-insensitive dictionary with additional methods for repeated
headers.
.. attribute:: body
Request body, if present, as a byte string.
.. attribute:: remote_ip
Client's IP address as a string. If ``HTTPServer.xheaders`` is set,
will pass along the real IP address provided by a load balancer
in the ``X-Real-Ip`` or ``X-Forwarded-For`` header.
.. versionchanged:: 3.1
The list format of ``X-Forwarded-For`` is now supported.
.. attribute:: protocol
The protocol used, either "http" or "https". If ``HTTPServer.xheaders``
is set, will pass along the protocol used by a load balancer if
reported via an ``X-Scheme`` header.
.. attribute:: host
The requested hostname, usually taken from the ``Host`` header.
.. attribute:: arguments
GET/POST arguments are available in the arguments property, which
maps arguments names to lists of values (to support multiple values
for individual names). Names are of type `str`, while arguments
are byte strings. Note that this is different from
`.RequestHandler.get_argument`, which returns argument values as
unicode strings.
.. attribute:: query_arguments
Same format as ``arguments``, but contains only arguments extracted
from the query string.
.. versionadded:: 3.2
.. attribute:: body_arguments
Same format as ``arguments``, but contains only arguments extracted
from the request body.
.. versionadded:: 3.2
.. attribute:: files
File uploads are available in the files property, which maps file
names to lists of `.HTTPFile`.
.. attribute:: connection
An HTTP request is attached to a single HTTP connection, which can
be accessed through the "connection" attribute. Since connections
are typically kept open in HTTP/1.1, multiple requests can be handled
sequentially on a single connection.
.. versionchanged:: 4.0
Moved from ``tornado.httpserver.HTTPRequest``.
"""
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
body=None, host=None, files=None, connection=None,
start_line=None):
if start_line is not None:
method, uri, version = start_line
self.method = method
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
self.body = body or ""
# set remote IP and protocol
context = getattr(connection, 'context', None)
self.remote_ip = getattr(context, 'remote_ip')
self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.files = files or {}
self.connection = connection
self._start_time = time.time()
self._finish_time = None
self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {}
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics.
.. deprecated:: 4.0
Applications are less likely to need this information with the
introduction of `.HTTPConnection`. If you still need it, access
the ``version`` attribute directly.
"""
return self.version == "HTTP/1.1"
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
except Exception:
self._cookies = {}
return self._cookies
def write(self, chunk, callback=None):
"""Writes the given chunk to the response stream.
.. deprecated:: 4.0
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
assert isinstance(chunk, bytes_type)
self.connection.write(chunk, callback=callback)
def finish(self):
"""Finishes this HTTP request on the open connection.
.. deprecated:: 4.0
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
self.connection.finish()
self._finish_time = time.time()
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
def request_time(self):
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
else:
return self._finish_time - self._start_time
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
DER-encoded form of the certificate is returned instead. See
SSLSocket.getpeercert() in the standard library for more
details.
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
return self.connection.stream.socket.getpeercert(
binary_form=binary_form)
except SSLError:
return None
def _parse_body(self):
parse_body_arguments(
self.headers.get("Content-Type", ""), self.body,
self.body_arguments, self.files,
self.headers)
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
def __repr__(self):
attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
class HTTPInputError(Exception):
"""Exception class for malformed HTTP requests or responses
from remote sources.
.. versionadded:: 4.0
"""
pass
class HTTPOutputError(Exception):
"""Exception class for errors in HTTP output.
.. versionadded:: 4.0
"""
pass
class HTTPServerConnectionDelegate(object):
"""Implement this interface to handle requests from `.HTTPServer`.
.. versionadded:: 4.0
"""
def start_request(self, server_conn, request_conn):
"""This method is called by the server when a new request has started.
:arg server_conn: is an opaque object representing the long-lived
(e.g. tcp-level) connection.
:arg request_conn: is a `.HTTPConnection` object for a single
request/response exchange.
This method should return a `.HTTPMessageDelegate`.
"""
raise NotImplementedError()
def on_close(self, server_conn):
"""This method is called when a connection has been closed.
:arg server_conn: is a server connection that has previously been
passed to ``start_request``.
"""
pass
class HTTPMessageDelegate(object):
"""Implement this interface to handle an HTTP request or response.
.. versionadded:: 4.0
"""
def headers_received(self, start_line, headers):
"""Called when the HTTP headers have been received and parsed.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
depending on whether this is a client or server message.
:arg headers: a `.HTTPHeaders` instance.
Some `.HTTPConnection` methods can only be called during
``headers_received``.
May return a `.Future`; if it does the body will not be read
until it is done.
"""
pass
def data_received(self, chunk):
"""Called when a chunk of data has been received.
May return a `.Future` for flow control.
"""
pass
def finish(self):
"""Called after the last chunk of data has been received."""
pass
def on_connection_close(self):
"""Called if the connection is closed without finishing the request.
If ``headers_received`` is called, either ``finish`` or
``on_connection_close`` will be called, but not both.
"""
pass
class HTTPConnection(object):
"""Applications use this interface to write their responses.
.. versionadded:: 4.0
"""
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Write an HTTP header block.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
:arg headers: a `.HTTPHeaders` instance.
:arg chunk: the first (optional) chunk of data. This is an optimization
so that small responses can be written in the same call as their
headers.
:arg callback: a callback to be run when the write is complete.
Returns a `.Future` if no callback is given.
"""
raise NotImplementedError()
def write(self, chunk, callback=None):
"""Writes a chunk of body data.
The callback will be run when the write is complete. If no callback
is given, returns a Future.
"""
raise NotImplementedError()
def finish(self):
"""Indicates that the last body data has been written.
"""
raise NotImplementedError()
def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether
url has existing query parameters.
>>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d'
"""
if not args:
return url
if url[-1] not in ('?', '&'):
url += '&' if ('?' in url) else '?'
return url + urlencode(args)
class HTTPFile(ObjectDict):
"""Represents a file uploaded via a form.
For backwards compatibility, its instance attributes are also
accessible as dictionary keys.
* ``filename``
* ``body``
* ``content_type``
"""
pass
def _parse_request_range(range_header):
"""Parses a Range header.
Returns either ``None`` or tuple ``(start, end)``.
Note that while the HTTP headers use inclusive byte positions,
this method returns indexes suitable for use in slices.
>>> start, end = _parse_request_range("bytes=1-2")
>>> start, end
(1, 3)
>>> [0, 1, 2, 3, 4][start:end]
[1, 2]
>>> _parse_request_range("bytes=6-")
(6, None)
>>> _parse_request_range("bytes=-6")
(-6, None)
>>> _parse_request_range("bytes=-0")
(None, 0)
>>> _parse_request_range("bytes=")
(None, None)
>>> _parse_request_range("foo=42")
>>> _parse_request_range("bytes=1-2,6-10")
Note: only supports one range (ex, ``bytes=1-2,6-10`` is not allowed).
See [0] for the details of the range header.
[0]: http://greenbytes.de/tech/webdav/draft-ietf-httpbis-p5-range-latest.html#byte.ranges
"""
unit, _, value = range_header.partition("=")
unit, value = unit.strip(), value.strip()
if unit != "bytes":
return None
start_b, _, end_b = value.partition("-")
try:
start = _int_or_none(start_b)
end = _int_or_none(end_b)
except ValueError:
return None
if end is not None:
if start is None:
if end != 0:
start = -end
end = None
else:
end += 1
return (start, end)
def _get_content_range(start, end, total):
"""Returns a suitable Content-Range header:
>>> print(_get_content_range(None, 1, 4))
bytes 0-0/4
>>> print(_get_content_range(1, 3, 4))
bytes 1-2/4
>>> print(_get_content_range(None, None, 4))
bytes 0-3/4
"""
start = start or 0
end = (end or total) - 1
return "bytes %s-%s/%s" % (start, end, total)
def _int_or_none(val):
val = val.strip()
if val == "":
return None
return int(val)
def parse_body_arguments(content_type, body, arguments, files, headers=None):
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
``multipart/form-data``. The ``content_type`` parameter should be
a string and ``body`` should be a byte string. The ``arguments``
and ``files`` parameters are dictionaries that will be updated
with the parsed contents.
"""
if headers and 'Content-Encoding' in headers:
gen_log.warning("Unsupported Content-Encoding: %s",
headers['Content-Encoding'])
return
if content_type.startswith("application/x-www-form-urlencoded"):
try:
uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
except Exception as e:
gen_log.warning('Invalid x-www-form-urlencoded body: %s', e)
uri_arguments = {}
for name, values in uri_arguments.items():
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
gen_log.warning("Invalid multipart/form-data")
def parse_multipart_form_data(boundary, data, arguments, files):
"""Parses a ``multipart/form-data`` body.
The ``boundary`` and ``data`` parameters are both byte strings.
The dictionaries given in the arguments and files parameters
will be updated with the contents of the body.
"""
# The standard allows for the boundary to be quoted in the header,
# although it's rare (it happens at least for google app engine
# xmpp). I think we're also supposed to handle backslash-escapes
# here but I'll save that until we see a client that uses them
# in the wild.
if boundary.startswith(b'"') and boundary.endswith(b'"'):
boundary = boundary[1:-1]
final_boundary_index = data.rfind(b"--" + boundary + b"--")
if final_boundary_index == -1:
gen_log.warning("Invalid multipart/form-data: no final boundary")
return
parts = data[:final_boundary_index].split(b"--" + boundary + b"\r\n")
for part in parts:
if not part:
continue
eoh = part.find(b"\r\n\r\n")
if eoh == -1:
gen_log.warning("multipart/form-data missing headers")
continue
headers = HTTPHeaders.parse(part[:eoh].decode("utf-8"))
disp_header = headers.get("Content-Disposition", "")
disposition, disp_params = _parse_header(disp_header)
if disposition != "form-data" or not part.endswith(b"\r\n"):
gen_log.warning("Invalid multipart/form-data")
continue
value = part[eoh + 4:-2]
if not disp_params.get("name"):
gen_log.warning("multipart/form-data value missing name")
continue
name = disp_params["name"]
if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
files.setdefault(name, []).append(HTTPFile(
filename=disp_params["filename"], body=value,
content_type=ctype))
else:
arguments.setdefault(name, []).append(value)
def format_timestamp(ts):
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
a time tuple as returned by `time.gmtime`, or a `datetime.datetime`
object.
>>> format_timestamp(1359312200)
'Sun, 27 Jan 2013 18:43:20 GMT'
"""
if isinstance(ts, numbers.Real):
pass
elif isinstance(ts, (tuple, time.struct_time)):
ts = calendar.timegm(ts)
elif isinstance(ts, datetime.datetime):
ts = calendar.timegm(ts.utctimetuple())
else:
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(ts, usegmt=True)
RequestStartLine = collections.namedtuple(
'RequestStartLine', ['method', 'path', 'version'])
def parse_request_start_line(line):
"""Returns a (method, path, version) tuple for an HTTP 1.x request line.
The response is a `collections.namedtuple`.
>>> parse_request_start_line("GET /foo HTTP/1.1")
RequestStartLine(method='GET', path='/foo', version='HTTP/1.1')
"""
try:
method, path, version = line.split(" ")
except ValueError:
raise HTTPInputError("Malformed HTTP request line")
if not version.startswith("HTTP/"):
raise HTTPInputError(
"Malformed HTTP version in HTTP Request-Line: %r" % version)
return RequestStartLine(method, path, version)
ResponseStartLine = collections.namedtuple(
'ResponseStartLine', ['version', 'code', 'reason'])
def parse_response_start_line(line):
"""Returns a (version, code, reason) tuple for an HTTP 1.x response line.
The response is a `collections.namedtuple`.
>>> parse_response_start_line("HTTP/1.1 200 OK")
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
return ResponseStartLine(match.group(1), int(match.group(2)),
match.group(3))
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
def _parseparam(s):
while s[:1] == ';':
s = s[1:]
end = s.find(';')
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
end = s.find(';', end + 1)
if end < 0:
end = len(s)
f = s[:end]
yield f.strip()
s = s[end:]
def _parse_header(line):
"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
"""
parts = _parseparam(';' + line)
key = next(parts)
pdict = {}
for p in parts:
i = p.find('=')
if i >= 0:
name = p[:i].strip().lower()
value = p[i + 1:].strip()
if len(value) >= 2 and value[0] == value[-1] == '"':
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value
return key, pdict
def doctests():
import doctest
return doctest.DocTestSuite()

View file

@ -0,0 +1,982 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""An I/O event loop for non-blocking sockets.
Typical applications will use a single `IOLoop` object, in the
`IOLoop.instance` singleton. The `IOLoop.start` method should usually
be called at the end of the ``main()`` function. Atypical applications may
use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest`
case.
In addition to I/O events, the `IOLoop` can also schedule time-based events.
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import errno
import functools
import heapq
import itertools
import logging
import numbers
import os
import select
import sys
import threading
import time
import traceback
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
from tornado import stack_context
from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
try:
import signal
except ImportError:
signal = None
try:
import thread # py2
except ImportError:
import _thread as thread # py3
from tornado.platform.auto import set_close_exec, Waker
_POLL_TIMEOUT = 3600.0
class TimeoutError(Exception):
pass
class IOLoop(Configurable):
"""A level-triggered I/O loop.
We use ``epoll`` (Linux) or ``kqueue`` (BSD and Mac OS X) if they
are available, or else we fall back on select(). If you are
implementing a system that needs to handle thousands of
simultaneous connections, you should use a system that supports
either ``epoll`` or ``kqueue``.
Example usage for a simple TCP server::
import errno
import functools
import ioloop
import socket
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error, e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
handle_connection(connection, address)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", port))
sock.listen(128)
io_loop = ioloop.IOLoop.instance()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
"""
# Constants from the epoll module
_EPOLLIN = 0x001
_EPOLLPRI = 0x002
_EPOLLOUT = 0x004
_EPOLLERR = 0x008
_EPOLLHUP = 0x010
_EPOLLRDHUP = 0x2000
_EPOLLONESHOT = (1 << 30)
_EPOLLET = (1 << 31)
# Our events map exactly to the epoll events
NONE = 0
READ = _EPOLLIN
WRITE = _EPOLLOUT
ERROR = _EPOLLERR | _EPOLLHUP
# Global lock for creating global IOLoop instance
_instance_lock = threading.Lock()
_current = threading.local()
@staticmethod
def instance():
"""Returns a global `IOLoop` instance.
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
"""
if not hasattr(IOLoop, "_instance"):
with IOLoop._instance_lock:
if not hasattr(IOLoop, "_instance"):
# New instance after double check
IOLoop._instance = IOLoop()
return IOLoop._instance
@staticmethod
def initialized():
"""Returns true if the singleton instance has been created."""
return hasattr(IOLoop, "_instance")
def install(self):
"""Installs this `IOLoop` object as the singleton instance.
This is normally not necessary as `instance()` will create
an `IOLoop` on demand, but you may want to call `install` to use
a custom subclass of `IOLoop`.
"""
assert not IOLoop.initialized()
IOLoop._instance = self
@staticmethod
def clear_instance():
"""Clear the global `IOLoop` instance.
.. versionadded:: 4.0
"""
if hasattr(IOLoop, "_instance"):
del IOLoop._instance
@staticmethod
def current():
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as current
by `make_current`, returns that instance. Otherwise returns
`IOLoop.instance()`, i.e. the main thread's `IOLoop`.
A common pattern for classes that depend on ``IOLoops`` is to use
a default argument to enable programs with multiple ``IOLoops``
but not require the argument for simpler applications::
class MyClass(object):
def __init__(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
In general you should use `IOLoop.current` as the default when
constructing an asynchronous object, and use `IOLoop.instance`
when you mean to communicate to the main thread from a different
one.
"""
current = getattr(IOLoop._current, "instance", None)
if current is None:
return IOLoop.instance()
return current
def make_current(self):
"""Makes this the `IOLoop` for the current thread.
An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call
`make_current` explictly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
"""
IOLoop._current.instance = self
@staticmethod
def clear_current():
IOLoop._current.instance = None
@classmethod
def configurable_base(cls):
return IOLoop
@classmethod
def configurable_default(cls):
if hasattr(select, "epoll"):
from tornado.platform.epoll import EPollIOLoop
return EPollIOLoop
if hasattr(select, "kqueue"):
# Python 2.6+ on BSD or Mac
from tornado.platform.kqueue import KQueueIOLoop
return KQueueIOLoop
from tornado.platform.select import SelectIOLoop
return SelectIOLoop
def initialize(self):
pass
def close(self, all_fds=False):
"""Closes the `IOLoop`, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
IOLoop will be closed (not just the ones created by the
`IOLoop` itself).
Many applications will only use a single `IOLoop` that runs for the
entire lifetime of the process. In that case closing the `IOLoop`
is not necessary since everything will be cleaned up when the
process exits. `IOLoop.close` is provided mainly for scenarios
such as unit tests, which create and destroy a large number of
``IOLoops``.
An `IOLoop` must be completely stopped before it can be closed. This
means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must
be allowed to return before attempting to call `IOLoop.close()`.
Therefore the call to `close` will usually appear just after
the call to `start` rather than near the call to `stop`.
.. versionchanged:: 3.1
If the `IOLoop` implementation supports non-integer objects
for "file descriptors", those objects will have their
``close`` method when ``all_fds`` is true.
"""
raise NotImplementedError()
def add_handler(self, fd, handler, events):
"""Registers the given handler to receive the given events for ``fd``.
The ``fd`` argument may either be an integer file descriptor or
a file-like object with a ``fileno()`` method (and optionally a
``close()`` method, which may be called when the `IOLoop` is shut
down).
The ``events`` argument is a bitwise or of the constants
``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``.
When an event occurs, ``handler(fd, events)`` will be run.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def update_handler(self, fd, events):
"""Changes the events we listen for ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def remove_handler(self, fd):
"""Stop listening for events on ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def set_blocking_signal_threshold(self, seconds, action):
"""Sends a signal if the `IOLoop` is blocked for more than
``s`` seconds.
Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy
platform.
The action parameter is a Python signal handler. Read the
documentation for the `signal` module for more information.
If ``action`` is None, the process will be killed if it is
blocked for too long.
"""
raise NotImplementedError()
def set_blocking_log_threshold(self, seconds):
"""Logs a stack trace if the `IOLoop` is blocked for more than
``s`` seconds.
Equivalent to ``set_blocking_signal_threshold(seconds,
self.log_stack)``
"""
self.set_blocking_signal_threshold(seconds, self.log_stack)
def log_stack(self, signal, frame):
"""Signal handler to log the stack trace of the current thread.
For use with `set_blocking_signal_threshold`.
"""
gen_log.warning('IOLoop blocked for %f seconds in\n%s',
self._blocking_signal_threshold,
''.join(traceback.format_stack(frame)))
def start(self):
"""Starts the I/O loop.
The loop will run until one of the callbacks calls `stop()`, which
will make the loop stop after the current event iteration completes.
"""
raise NotImplementedError()
def _setup_logging(self):
"""The IOLoop catches and logs exceptions, so it's
important that log output be visible. However, python's
default behavior for non-root loggers (prior to python
3.2) is to print an unhelpful "no handlers could be
found" message rather than the actual log entry, so we
must explicitly configure logging if we've made it this
far without anything.
This method should be called from start() in subclasses.
"""
if not any([logging.getLogger().handlers,
logging.getLogger('tornado').handlers,
logging.getLogger('tornado.application').handlers]):
logging.basicConfig()
def stop(self):
"""Stop the I/O loop.
If the event loop is not currently running, the next call to `start()`
will return immediately.
To use asynchronous methods from otherwise-synchronous code (such as
unit tests), you can start and stop the event loop like this::
ioloop = IOLoop()
async_method(ioloop=ioloop, callback=ioloop.stop)
ioloop.start()
``ioloop.start()`` will return after ``async_method`` has run
its callback, whether that callback was invoked before or
after ``ioloop.start``.
Note that even after `stop` has been called, the `IOLoop` is not
completely stopped until `IOLoop.start` has also returned.
Some work that was scheduled before the call to `stop` may still
be run before the `IOLoop` shuts down.
"""
raise NotImplementedError()
def run_sync(self, func, timeout=None):
"""Starts the `IOLoop`, runs the given function, and stops the loop.
If the function returns a `.Future`, the `IOLoop` will run
until the future is resolved. If it raises an exception, the
`IOLoop` will stop and the exception will be re-raised to the
caller.
The keyword-only argument ``timeout`` may be used to set
a maximum duration for the function. If the timeout expires,
a `TimeoutError` is raised.
This method is useful in conjunction with `tornado.gen.coroutine`
to allow asynchronous calls in a ``main()`` function::
@gen.coroutine
def main():
# do stuff...
if __name__ == '__main__':
IOLoop.instance().run_sync(main)
"""
future_cell = [None]
def run():
try:
result = func()
except Exception:
future_cell[0] = TracebackFuture()
future_cell[0].set_exc_info(sys.exc_info())
else:
if is_future(result):
future_cell[0] = result
else:
future_cell[0] = TracebackFuture()
future_cell[0].set_result(result)
self.add_future(future_cell[0], lambda future: self.stop())
self.add_callback(run)
if timeout is not None:
timeout_handle = self.add_timeout(self.time() + timeout, self.stop)
self.start()
if timeout is not None:
self.remove_timeout(timeout_handle)
if not future_cell[0].done():
raise TimeoutError('Operation timed out after %s seconds' % timeout)
return future_cell[0].result()
def time(self):
"""Returns the current time according to the `IOLoop`'s clock.
The return value is a floating-point number relative to an
unspecified time in the past.
By default, the `IOLoop`'s time function is `time.time`. However,
it may be configured to use e.g. `time.monotonic` instead.
Calls to `add_timeout` that pass a number instead of a
`datetime.timedelta` should use this function to compute the
appropriate time, so they can work no matter what time function
is chosen.
"""
return time.time()
def add_timeout(self, deadline, callback, *args, **kwargs):
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to
`remove_timeout` to cancel.
``deadline`` may be a number denoting a time (on the same
scale as `IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time. Since Tornado 4.0, `call_later` is a more
convenient alternative for the relative case since it does not
require a timedelta object.
Note that it is not safe to call `add_timeout` from other threads.
Instead, you must use `add_callback` to transfer control to the
`IOLoop`'s thread, and then call `add_timeout` from there.
Subclasses of IOLoop must implement either `add_timeout` or
`call_at`; the default implementations of each will call
the other. `call_at` is usually easier to implement, but
subclasses that wish to maintain compatibility with Tornado
versions prior to 4.0 must use `add_timeout` instead.
.. versionchanged:: 4.0
Now passes through ``*args`` and ``**kwargs`` to the callback.
"""
if isinstance(deadline, numbers.Real):
return self.call_at(deadline, callback, *args, **kwargs)
elif isinstance(deadline, datetime.timedelta):
return self.call_at(self.time() + timedelta_to_seconds(deadline),
callback, *args, **kwargs)
else:
raise TypeError("Unsupported deadline %r" % deadline)
def call_later(self, delay, callback, *args, **kwargs):
"""Runs the ``callback`` after ``delay`` seconds have passed.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(self, when, callback, *args, **kwargs):
"""Runs the ``callback`` at the absolute time designated by ``when``.
``when`` must be a number using the same reference point as
`IOLoop.time`.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout):
"""Cancels a pending timeout.
The argument is a handle as returned by `add_timeout`. It is
safe to call `remove_timeout` even if the callback has already
been run.
"""
raise NotImplementedError()
def add_callback(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
except from a signal handler. Note that this is the **only**
method in `IOLoop` that makes this thread-safety guarantee; all
other interaction with the `IOLoop` must be done from that
`IOLoop`'s thread. `add_callback()` may be used to transfer
control from other threads to the `IOLoop`'s thread.
To add a callback from a signal handler, see
`add_callback_from_signal`.
"""
raise NotImplementedError()
def add_callback_from_signal(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
otherwise.
Callbacks added with this method will be run without any
`.stack_context`, to avoid picking up the context of the function
that was interrupted by the signal.
"""
raise NotImplementedError()
def spawn_callback(self, callback, *args, **kwargs):
"""Calls the given callback on the next IOLoop iteration.
Unlike all other callback-related methods on IOLoop,
``spawn_callback`` does not associate the callback with its caller's
``stack_context``, so it is suitable for fire-and-forget callbacks
that should not interfere with the caller.
.. versionadded:: 4.0
"""
with stack_context.NullContext():
self.add_callback(callback, *args, **kwargs)
def add_future(self, future, callback):
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
The callback is invoked with one argument, the
`.Future`.
"""
assert is_future(future)
callback = stack_context.wrap(callback)
future.add_done_callback(
lambda future: self.add_callback(callback, future))
def _run_callback(self, callback):
"""Runs a callback with error handling.
For use in subclasses.
"""
try:
ret = callback()
if ret is not None and is_future(ret):
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
# gets logged too.
self.add_future(ret, lambda f: f.result())
except Exception:
self.handle_callback_exception(callback)
def handle_callback_exception(self, callback):
"""This method is called whenever a callback run by the `IOLoop`
throws an exception.
By default simply logs the exception as an error. Subclasses
may override this method to customize reporting of exceptions.
The exception itself is not passed explicitly, but is available
in `sys.exc_info`.
"""
app_log.error("Exception in callback %r", callback, exc_info=True)
def split_fd(self, fd):
"""Returns an (fd, obj) pair from an ``fd`` parameter.
We accept both raw file descriptors and file-like objects as
input to `add_handler` and related methods. When a file-like
object is passed, we must retain the object itself so we can
close it correctly when the `IOLoop` shuts down, but the
poller interfaces favor file descriptors (they will accept
file-like objects and call ``fileno()`` for you, but they
always return the descriptor itself).
This method is provided for use by `IOLoop` subclasses and should
not generally be used by application code.
.. versionadded:: 4.0
"""
try:
return fd.fileno(), fd
except AttributeError:
return fd, fd
def close_fd(self, fd):
"""Utility method to close an ``fd``.
If ``fd`` is a file-like object, we close it directly; otherwise
we use `os.close`.
This method is provided for use by `IOLoop` subclasses (in
implementations of ``IOLoop.close(all_fds=True)`` and should
not generally be used by application code.
.. versionadded:: 4.0
"""
try:
try:
fd.close()
except AttributeError:
os.close(fd)
except OSError:
pass
class PollIOLoop(IOLoop):
"""Base class for IOLoops built around a select-like function.
For concrete implementations, see `tornado.platform.epoll.EPollIOLoop`
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
`tornado.platform.select.SelectIOLoop` (all platforms).
"""
def initialize(self, impl, time_func=None):
super(PollIOLoop, self).initialize()
self._impl = impl
if hasattr(self._impl, 'fileno'):
set_close_exec(self._impl.fileno())
self.time_func = time_func or time.time
self._handlers = {}
self._events = {}
self._callbacks = []
self._callback_lock = threading.Lock()
self._timeouts = []
self._cancellations = 0
self._running = False
self._stopped = False
self._closing = False
self._thread_ident = None
self._blocking_signal_threshold = None
self._timeout_counter = itertools.count()
# Create a pipe that we send bogus data to when we want to wake
# the I/O loop when it is idle
self._waker = Waker()
self.add_handler(self._waker.fileno(),
lambda fd, events: self._waker.consume(),
self.READ)
def close(self, all_fds=False):
with self._callback_lock:
self._closing = True
self.remove_handler(self._waker.fileno())
if all_fds:
for fd, handler in self._handlers.values():
self.close_fd(fd)
self._waker.close()
self._impl.close()
self._callbacks = None
self._timeouts = None
def add_handler(self, fd, handler, events):
fd, obj = self.split_fd(fd)
self._handlers[fd] = (obj, stack_context.wrap(handler))
self._impl.register(fd, events | self.ERROR)
def update_handler(self, fd, events):
fd, obj = self.split_fd(fd)
self._impl.modify(fd, events | self.ERROR)
def remove_handler(self, fd):
fd, obj = self.split_fd(fd)
self._handlers.pop(fd, None)
self._events.pop(fd, None)
try:
self._impl.unregister(fd)
except Exception:
gen_log.debug("Error deleting fd from IOLoop", exc_info=True)
def set_blocking_signal_threshold(self, seconds, action):
if not hasattr(signal, "setitimer"):
gen_log.error("set_blocking_signal_threshold requires a signal module "
"with the setitimer method")
return
self._blocking_signal_threshold = seconds
if seconds is not None:
signal.signal(signal.SIGALRM,
action if action is not None else signal.SIG_DFL)
def start(self):
if self._running:
raise RuntimeError("IOLoop is already running")
self._setup_logging()
if self._stopped:
self._stopped = False
return
old_current = getattr(IOLoop._current, "instance", None)
IOLoop._current.instance = self
self._thread_ident = thread.get_ident()
self._running = True
# signal.set_wakeup_fd closes a race condition in event loops:
# a signal may arrive at the beginning of select/poll/etc
# before it goes into its interruptible sleep, so the signal
# will be consumed without waking the select. The solution is
# for the (C, synchronous) signal handler to write to a pipe,
# which will then be seen by select.
#
# In python's signal handling semantics, this only matters on the
# main thread (fortunately, set_wakeup_fd only works on the main
# thread and will raise a ValueError otherwise).
#
# If someone has already set a wakeup fd, we don't want to
# disturb it. This is an issue for twisted, which does its
# SIGCHILD processing in response to its own wakeup fd being
# written to. As long as the wakeup fd is registered on the IOLoop,
# the loop will still wake up and everything should work.
old_wakeup_fd = None
if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix':
# requires python 2.6+, unix. set_wakeup_fd exists but crashes
# the python process on windows.
try:
old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno())
if old_wakeup_fd != -1:
# Already set, restore previous value. This is a little racy,
# but there's no clean get_wakeup_fd and in real use the
# IOLoop is just started once at the beginning.
signal.set_wakeup_fd(old_wakeup_fd)
old_wakeup_fd = None
except ValueError: # non-main thread
pass
try:
while True:
# Prevent IO event starvation by delaying new callbacks
# to the next iteration of the event loop.
with self._callback_lock:
callbacks = self._callbacks
self._callbacks = []
# Add any timeouts that have come due to the callback list.
# Do not run anything until we have determined which ones
# are ready, so timeouts that call add_timeout cannot
# schedule anything in this iteration.
if self._timeouts:
now = self.time()
while self._timeouts:
if self._timeouts[0].callback is None:
# the timeout was cancelled
heapq.heappop(self._timeouts)
self._cancellations -= 1
elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts)
callbacks.append(timeout.callback)
del timeout
else:
break
if (self._cancellations > 512
and self._cancellations > (len(self._timeouts) >> 1)):
# Clean up the timeout queue when it gets large and it's
# more than half cancellations.
self._cancellations = 0
self._timeouts = [x for x in self._timeouts
if x.callback is not None]
heapq.heapify(self._timeouts)
for callback in callbacks:
self._run_callback(callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = None
if self._callbacks:
# If any callbacks or timeouts called add_callback,
# we don't want to wait in poll() before we run them.
poll_timeout = 0.0
elif self._timeouts:
# If there are any timeouts, schedule the first one.
# Use self.time() instead of 'now' to account for time
# spent running callbacks.
poll_timeout = self._timeouts[0].deadline - self.time()
poll_timeout = max(0, min(poll_timeout, _POLL_TIMEOUT))
else:
# No timeouts and no callbacks, so use the default.
poll_timeout = _POLL_TIMEOUT
if not self._running:
break
if self._blocking_signal_threshold is not None:
# clear alarm so it doesn't fire while poll is waiting for
# events.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
try:
event_pairs = self._impl.poll(poll_timeout)
except Exception as e:
# Depending on python version and IOLoop implementation,
# different exception types may be thrown and there are
# two ways EINTR might be signaled:
# * e.errno == errno.EINTR
# * e.args is like (errno.EINTR, 'Interrupted system call')
if errno_from_exception(e) == errno.EINTR:
continue
else:
raise
if self._blocking_signal_threshold is not None:
signal.setitimer(signal.ITIMER_REAL,
self._blocking_signal_threshold, 0)
# Pop one fd at a time from the set of pending fds and run
# its handler. Since that handler may perform actions on
# other file descriptors, there may be reentrant calls to
# this IOLoop that update self._events
self._events.update(event_pairs)
while self._events:
fd, events = self._events.popitem()
try:
fd_obj, handler_func = self._handlers[fd]
handler_func(fd_obj, events)
except (OSError, IOError) as e:
if errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
pass
else:
self.handle_callback_exception(self._handlers.get(fd))
except Exception:
self.handle_callback_exception(self._handlers.get(fd))
fd_obj = handler_func = None
finally:
# reset the stopped flag so another start/stop pair can be issued
self._stopped = False
if self._blocking_signal_threshold is not None:
signal.setitimer(signal.ITIMER_REAL, 0, 0)
IOLoop._current.instance = old_current
if old_wakeup_fd is not None:
signal.set_wakeup_fd(old_wakeup_fd)
def stop(self):
self._running = False
self._stopped = True
self._waker.wake()
def time(self):
return self.time_func()
def call_at(self, deadline, callback, *args, **kwargs):
timeout = _Timeout(
deadline,
functools.partial(stack_context.wrap(callback), *args, **kwargs),
self)
heapq.heappush(self._timeouts, timeout)
return timeout
def remove_timeout(self, timeout):
# Removing from a heap is complicated, so just leave the defunct
# timeout object in the queue (see discussion in
# http://docs.python.org/library/heapq.html).
# If this turns out to be a problem, we could add a garbage
# collection pass whenever there are too many dead timeouts.
timeout.callback = None
self._cancellations += 1
def add_callback(self, callback, *args, **kwargs):
with self._callback_lock:
if self._closing:
raise RuntimeError("IOLoop is closing")
list_empty = not self._callbacks
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty and thread.get_ident() != self._thread_ident:
# If we're in the IOLoop's thread, we know it's not currently
# polling. If we're not, and we added the first callback to an
# empty list, we may need to wake it up (it may wake up on its
# own, but an occasional extra wake is harmless). Waking
# up a polling IOLoop is relatively expensive, so we try to
# avoid it when we can.
self._waker.wake()
def add_callback_from_signal(self, callback, *args, **kwargs):
with stack_context.NullContext():
if thread.get_ident() != self._thread_ident:
# if the signal is handled on another thread, we can add
# it normally (modulo the NullContext)
self.add_callback(callback, *args, **kwargs)
else:
# If we're on the IOLoop's thread, we cannot use
# the regular add_callback because it may deadlock on
# _callback_lock. Blindly insert into self._callbacks.
# This is safe because the GIL makes list.append atomic.
# One subtlety is that if the signal interrupted the
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks,
# but either way will work.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
__slots__ = ['deadline', 'callback', 'tiebreaker']
def __init__(self, deadline, callback, io_loop):
if not isinstance(deadline, numbers.Real):
raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback
self.tiebreaker = next(io_loop._timeout_counter)
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
def __lt__(self, other):
return ((self.deadline, self.tiebreaker) <
(other.deadline, other.tiebreaker))
def __le__(self, other):
return ((self.deadline, self.tiebreaker) <=
(other.deadline, other.tiebreaker))
class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
`start` must be called after the `PeriodicCallback` is created.
"""
def __init__(self, callback, callback_time, io_loop=None):
self.callback = callback
if callback_time <= 0:
raise ValueError("Periodic callback must have a positive callback_time")
self.callback_time = callback_time
self.io_loop = io_loop or IOLoop.current()
self._running = False
self._timeout = None
def start(self):
"""Starts the timer."""
self._running = True
self._next_timeout = self.io_loop.time()
self._schedule_next()
def stop(self):
"""Stops the timer."""
self._running = False
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _run(self):
if not self._running:
return
try:
self.callback()
except Exception:
self.io_loop.handle_callback_exception(self.callback)
self._schedule_next()
def _schedule_next(self):
if self._running:
current_time = self.io_loop.time()
while self._next_timeout <= current_time:
self._next_timeout += self.callback_time / 1000.0
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,511 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Translation methods for generating localized strings.
To load a locale and generate a translated string::
user_locale = tornado.locale.get("es_LA")
print user_locale.translate("Sign out")
`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with
additional arguments to `~Locale.translate()`, e.g.::
people = [...]
message = user_locale.translate(
"%(list)s is online", "%(list)s are online", len(people))
print message % {"list": user_locale.list(people)}
The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.
Applications should call one of `load_translations` (which uses a simple
CSV format) or `load_gettext_translations` (which uses the ``.mo`` format
supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""
from __future__ import absolute_import, division, print_function, with_statement
import csv
import datetime
import numbers
import os
import re
from tornado import escape
from tornado.log import gen_log
from tornado.util import u
_default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
def get(*locale_codes):
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
or a loose match for the code (e.g., "en" for "en_US"), we return
the locale. Otherwise we move to the next code in the list.
By default we return ``en_US`` if no translations are found for any of
the specified locales. You can change the default locale with
`set_default_locale()`.
"""
return Locale.get_closest(*locale_codes)
def set_default_locale(code):
"""Sets the default locale.
The default locale is assumed to be the language used for all strings
in the system. The translations loaded from disk are mappings from
the default locale to the destination locale. Consequently, you don't
need to create a translation file for the default locale.
"""
global _default_locale
global _supported_locales
_default_locale = code
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
def load_translations(directory):
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
(e.g., ``My name is %(name)s``) and their associated translations.
The directory should have translation files of the form ``LOCALE.csv``,
e.g. ``es_GT.csv``. The CSV files should have two or three columns: string,
translation, and an optional plural indicator. Plural indicators should
be one of "plural" or "singular". A given string can have both singular
and plural forms. For example ``%(name)s liked this`` may have a
different verb conjugation depending on whether %(name)s is one
name or a list of names. There should be two rows in the CSV file for
that string, one with plural indicator "singular", and one "plural".
For strings with no verbs that would change on translation, simply
use "unknown" or the empty string (or don't include the column at all).
The file is read using the `csv` module in the default "excel" dialect.
In this format there should not be spaces after the commas.
Example translation ``es_LA.csv``::
"I love you","Te amo"
"%(name)s liked this","A %(name)s les gustó esto","plural"
"%(name)s liked this","A %(name)s le gustó esto","singular"
"""
global _translations
global _supported_locales
_translations = {}
for path in os.listdir(directory):
if not path.endswith(".csv"):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
gen_log.error("Unrecognized locale %r (path: %s)", locale,
os.path.join(directory, path))
continue
full_path = os.path.join(directory, path)
try:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding="utf-8")
except TypeError:
# python 2: files return byte strings, which are decoded below.
f = open(full_path, "r")
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2:
continue
row = [escape.to_unicode(c).strip() for c in row]
english, translation = row[:2]
if len(row) > 2:
plural = row[2] or "unknown"
else:
plural = "unknown"
if plural not in ("plural", "singular", "unknown"):
gen_log.error("Unrecognized plural indicator %r in %s line %d",
plural, path, i + 1)
continue
_translations[locale].setdefault(plural, {})[english] = translation
f.close()
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def load_gettext_translations(directory, domain):
"""Loads translations from `gettext`'s locale tree
Locale tree is similar to system's ``/usr/share/locale``, like::
{directory}/{lang}/LC_MESSAGES/{domain}.mo
Three steps are required to have you app translated:
1. Generate POT translation file::
xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc
2. Merge against existing POT file::
msgmerge old.po mydomain.po > new.po
3. Compile::
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
import gettext
global _translations
global _supported_locales
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
if lang.startswith('.'):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
_translations[lang] = gettext.translation(domain, directory,
languages=[lang])
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
_use_gettext = True
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def get_supported_locales():
"""Returns a list of all the supported locale codes."""
return _supported_locales
class Locale(object):
"""Object representing a locale.
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
@classmethod
def get_closest(cls, *locale_codes):
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
continue
code = code.replace("-", "_")
parts = code.split("_")
if len(parts) > 2:
continue
elif len(parts) == 2:
code = parts[0].lower() + "_" + parts[1].upper()
if code in _supported_locales:
return cls.get(code)
if parts[0].lower() in _supported_locales:
return cls.get(parts[0].lower())
return cls.get(_default_locale)
@classmethod
def get(cls, code):
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
"""
if not hasattr(cls, "_cache"):
cls._cache = {}
if code not in cls._cache:
assert code in _supported_locales
translations = _translations.get(code, None)
if translations is None:
locale = CSVLocale(code, {})
elif _use_gettext:
locale = GettextLocale(code, translations)
else:
locale = CSVLocale(code, translations)
cls._cache[code] = locale
return cls._cache[code]
def __init__(self, code, translations):
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u("Unknown"))
self.rtl = False
for prefix in ["fa", "ar", "he"]:
if self.code.startswith(prefix):
self.rtl = True
break
self.translations = translations
# Initialize strings for date formatting
_ = self.translate
self._months = [
_("January"), _("February"), _("March"), _("April"),
_("May"), _("June"), _("July"), _("August"),
_("September"), _("October"), _("November"), _("December")]
self._weekdays = [
_("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
_("Friday"), _("Saturday"), _("Sunday")]
def translate(self, message, plural_message=None, count=None):
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
``count``. We return ``plural_message`` when ``count != 1``,
and we return the singular form for the given message when
``count == 1``.
"""
raise NotImplementedError()
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
can return an absolute date string with ``relative=False``.
You can force a full format date ("July 10, 1980") with
``full_format=True``.
This method is primarily intended for dates in the past.
For dates in the future, we fall back to full format.
"""
if isinstance(date, numbers.Real):
date = datetime.datetime.utcfromtimestamp(date)
now = datetime.datetime.utcnow()
if date > now:
if relative and (date - now).seconds < 60:
# Due to click skew, things are some things slightly
# in the future. Round timestamps in the immediate
# future down to now in relative mode.
date = now
else:
# Otherwise, future dates always use the full format.
full_format = True
local_date = date - datetime.timedelta(minutes=gmt_offset)
local_now = now - datetime.timedelta(minutes=gmt_offset)
local_yesterday = local_now - datetime.timedelta(hours=24)
difference = now - date
seconds = difference.seconds
days = difference.days
_ = self.translate
format = None
if not full_format:
if relative and days == 0:
if seconds < 50:
return _("1 second ago", "%(seconds)d seconds ago",
seconds) % {"seconds": seconds}
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
return _("1 minute ago", "%(minutes)d minutes ago",
minutes) % {"minutes": minutes}
hours = round(seconds / (60.0 * 60))
return _("1 hour ago", "%(hours)d hours ago",
hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
elif days == 1 and local_date.day == local_yesterday.day and \
relative:
format = _("yesterday") if shorter else \
_("yesterday at %(time)s")
elif days < 5:
format = _("%(weekday)s") if shorter else \
_("%(weekday)s at %(time)s")
elif days < 334: # 11mo, since confusing for same month last year
format = _("%(month_name)s %(day)s") if shorter else \
_("%(month_name)s %(day)s at %(time)s")
if format is None:
format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
_("%(month_name)s %(day)s, %(year)s at %(time)s")
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
(u('\u4e0a\u5348'), u('\u4e0b\u5348'))[local_date.hour >= 12],
local_date.hour % 12 or 12, local_date.minute)
else:
str_time = "%d:%02d %s" % (
local_date.hour % 12 or 12, local_date.minute,
("am", "pm")[local_date.hour >= 12])
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
"time": str_time
}
def format_day(self, date, gmt_offset=0, dow=True):
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
``dow=False``.
"""
local_date = date - datetime.timedelta(minutes=gmt_offset)
_ = self.translate
if dow:
return _("%(weekday)s, %(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
}
else:
return _("%(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"day": str(local_date.day),
}
def list(self, parts):
"""Returns a comma-separated list for the given list of parts.
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
of size 1.
"""
_ = self.translate
if len(parts) == 0:
return ""
if len(parts) == 1:
return parts[0]
comma = u(' \u0648 ') if self.code.startswith("fa") else u(", ")
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
}
def friendly_number(self, value):
"""Returns a comma-separated number for the given integer."""
if self.code not in ("en", "en_US"):
return str(value)
value = str(value)
parts = []
while value:
parts.append(value[-3:])
value = value[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
if count != 1:
message = plural_message
message_dict = self.translations.get("plural", {})
else:
message_dict = self.translations.get("singular", {})
else:
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
def __init__(self, code, translations):
try:
# python 2
self.ngettext = translations.ungettext
self.gettext = translations.ugettext
except AttributeError:
# python 3
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
super(GettextLocale, self).__init__(code, translations)
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},
"ar_AR": {"name_en": u("Arabic"), "name": u("\u0627\u0644\u0639\u0631\u0628\u064a\u0629")},
"bg_BG": {"name_en": u("Bulgarian"), "name": u("\u0411\u044a\u043b\u0433\u0430\u0440\u0441\u043a\u0438")},
"bn_IN": {"name_en": u("Bengali"), "name": u("\u09ac\u09be\u0982\u09b2\u09be")},
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
"ca_ES": {"name_en": u("Catalan"), "name": u("Catal\xe0")},
"cs_CZ": {"name_en": u("Czech"), "name": u("\u010ce\u0161tina")},
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
"el_GR": {"name_en": u("Greek"), "name": u("\u0395\u03bb\u03bb\u03b7\u03bd\u03b9\u03ba\u03ac")},
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Espa\xf1ol (Espa\xf1a)")},
"es_LA": {"name_en": u("Spanish"), "name": u("Espa\xf1ol")},
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
"fa_IR": {"name_en": u("Persian"), "name": u("\u0641\u0627\u0631\u0633\u06cc")},
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Fran\xe7ais (Canada)")},
"fr_FR": {"name_en": u("French"), "name": u("Fran\xe7ais")},
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
"he_IL": {"name_en": u("Hebrew"), "name": u("\u05e2\u05d1\u05e8\u05d9\u05ea")},
"hi_IN": {"name_en": u("Hindi"), "name": u("\u0939\u093f\u0928\u094d\u0926\u0940")},
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
"is_IS": {"name_en": u("Icelandic"), "name": u("\xcdslenska")},
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
"ja_JP": {"name_en": u("Japanese"), "name": u("\u65e5\u672c\u8a9e")},
"ko_KR": {"name_en": u("Korean"), "name": u("\ud55c\uad6d\uc5b4")},
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvi\u0173")},
"lv_LV": {"name_en": u("Latvian"), "name": u("Latvie\u0161u")},
"mk_MK": {"name_en": u("Macedonian"), "name": u("\u041c\u0430\u043a\u0435\u0434\u043e\u043d\u0441\u043a\u0438")},
"ml_IN": {"name_en": u("Malayalam"), "name": u("\u0d2e\u0d32\u0d2f\u0d3e\u0d33\u0d02")},
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokm\xe5l)")},
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
"pa_IN": {"name_en": u("Punjabi"), "name": u("\u0a2a\u0a70\u0a1c\u0a3e\u0a2c\u0a40")},
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Portugu\xeas (Brasil)")},
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Portugu\xeas (Portugal)")},
"ro_RO": {"name_en": u("Romanian"), "name": u("Rom\xe2n\u0103")},
"ru_RU": {"name_en": u("Russian"), "name": u("\u0420\u0443\u0441\u0441\u043a\u0438\u0439")},
"sk_SK": {"name_en": u("Slovak"), "name": u("Sloven\u010dina")},
"sl_SI": {"name_en": u("Slovenian"), "name": u("Sloven\u0161\u010dina")},
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
"sr_RS": {"name_en": u("Serbian"), "name": u("\u0421\u0440\u043f\u0441\u043a\u0438")},
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
"ta_IN": {"name_en": u("Tamil"), "name": u("\u0ba4\u0bae\u0bbf\u0bb4\u0bcd")},
"te_IN": {"name_en": u("Telugu"), "name": u("\u0c24\u0c46\u0c32\u0c41\u0c17\u0c41")},
"th_TH": {"name_en": u("Thai"), "name": u("\u0e20\u0e32\u0e29\u0e32\u0e44\u0e17\u0e22")},
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
"tr_TR": {"name_en": u("Turkish"), "name": u("T\xfcrk\xe7e")},
"uk_UA": {"name_en": u("Ukraini "), "name": u("\u0423\u043a\u0440\u0430\u0457\u043d\u0441\u044c\u043a\u0430")},
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Ti\u1ebfng Vi\u1ec7t")},
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("\u4e2d\u6587(\u7b80\u4f53)")},
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("\u4e2d\u6587(\u7e41\u9ad4)")},
}

View file

@ -0,0 +1,230 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Logging support for Tornado.
Tornado uses three logger streams:
* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and
potentially other servers in the future)
* ``tornado.application``: Logging of errors from application code (i.e.
uncaught exceptions from callbacks)
* ``tornado.general``: General-purpose logging, including any errors
or warnings from Tornado itself.
These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
from __future__ import absolute_import, division, print_function, with_statement
import logging
import logging.handlers
import sys
from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type
try:
import curses
except ImportError:
curses = None
# Logger objects for internal tornado use
access_log = logging.getLogger("tornado.access")
app_log = logging.getLogger("tornado.application")
gen_log = logging.getLogger("tornado.general")
def _stderr_supports_color():
color = False
if curses and hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
try:
curses.setupterm()
if curses.tigetnum("colors") > 0:
color = True
except Exception:
pass
return color
def _safe_unicode(s):
try:
return _unicode(s)
except UnicodeDecodeError:
return repr(s)
class LogFormatter(logging.Formatter):
"""Log formatter used in Tornado.
Key features of this formatter are:
* Color support when logging to a terminal that supports it.
* Timestamps on every log line.
* Robust against str/bytes encoding problems.
This formatter is enabled automatically by
`tornado.options.parse_command_line` (unless ``--logging=none`` is
used).
"""
DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
}
def __init__(self, color=True, fmt=DEFAULT_FORMAT,
datefmt=DEFAULT_DATE_FORMAT, colors=DEFAULT_COLORS):
r"""
:arg bool color: Enables color support.
:arg string fmt: Log message format.
It will be applied to the attributes dict of log records. The
text between ``%(color)s`` and ``%(end_color)s`` will be colored
depending on the level if color support is on.
:arg dict colors: color mappings from logging level to terminal color
code
:arg string datefmt: Datetime format.
Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``.
.. versionchanged:: 3.2
Added ``fmt`` and ``datefmt`` arguments.
"""
logging.Formatter.__init__(self, datefmt=datefmt)
self._fmt = fmt
self._colors = {}
if color and _stderr_supports_color():
# The curses module has some str/bytes confusion in
# python3. Until version 3.2.3, most methods return
# bytes, but only accept strings. In addition, we want to
# output these strings with the logging module, which
# works with unicode strings. The explicit calls to
# unicode() below are harmless in python2 but will do the
# right conversion in python 3.
fg_color = (curses.tigetstr("setaf") or
curses.tigetstr("setf") or "")
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode_type(fg_color, "ascii")
for levelno, code in colors.items():
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
self._normal = ''
def format(self, record):
try:
message = record.getMessage()
assert isinstance(message, basestring_type) # guaranteed by logging
# Encoding notes: The logging module prefers to work with character
# strings, but only enforces that log messages are instances of
# basestring. In python 2, non-ascii bytestrings will make
# their way through the logging framework until they blow up with
# an unhelpful decoding error (with this formatter it happens
# when we attach the prefix, but there are other opportunities for
# exceptions further along in the framework).
#
# If a byte string makes it this far, convert it to unicode to
# ensure it will make it out to the logs. Use repr() as a fallback
# to ensure that all byte strings can be converted successfully,
# but don't do it by default so we don't add extra quotes to ascii
# bytestrings. This is a bit of a hacky place to do this, but
# it's worth it since the encoding errors that would otherwise
# result are so useless (and tornado is fond of using utf8-encoded
# byte strings whereever possible).
record.message = _safe_unicode(message)
except Exception as e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
record.asctime = self.formatTime(record, self.datefmt)
if record.levelno in self._colors:
record.color = self._colors[record.levelno]
record.end_color = self._normal
else:
record.color = record.end_color = ''
formatted = self._fmt % record.__dict__
if record.exc_info:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
# exc_text contains multiple lines. We need to _safe_unicode
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n'))
formatted = '\n'.join(lines)
return formatted.replace("\n", "\n ")
def enable_pretty_logging(options=None, logger=None):
"""Turns on formatted logging output as configured.
This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`.
"""
if options is None:
from tornado.options import options
if options.logging is None or options.logging.lower() == 'none':
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
if (options.log_to_stderr or
(options.log_to_stderr is None and not logger.handlers)):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
def define_logging_options(options=None):
if options is None:
# late import to prevent cycle
from tornado.options import options
options.define("logging", default="info",
help=("Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."),
metavar="debug|info|warning|error|none")
options.define("log_to_stderr", type=bool, default=None,
help=("Send log output to stderr (colorized if possible). "
"By default use stderr if --log_file_prefix is not set and "
"no other logging is configured."))
options.define("log_file_prefix", type=str, default=None, metavar="PATH",
help=("Path prefix for log files. "
"Note that if you are running multiple tornado processes, "
"log_file_prefix must be different for each of them (e.g. "
"include the port number)"))
options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
help="max size of log files before rollover")
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
options.add_parse_callback(enable_pretty_logging)

View file

@ -0,0 +1,445 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Miscellaneous network utility code."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import platform
import socket
import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import u, Configurable, errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine
ssl = None
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
elif ssl is None:
ssl_match_hostname = SSLCertificateError = None
else:
import backports.ssl_match_hostname
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
SSLCertificateError = backports.ssl_match_hostname.CertificateError
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
u('foo').encode('idna')
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
# some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
the given address maps to multiple IP addresses, which is most common
for mixed IPv4 and IPv6 use).
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen() <socket.socket.listen>`.
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
"""
sockets = []
if address == "":
address = None
if not socket.has_ipv6 and family == socket.AF_UNSPEC:
# Python can be compiled with --disable-ipv6, which causes
# operations on AF_INET6 sockets to fail, but does not
# automatically exclude those results from getaddrinfo
# results.
# http://bugs.python.org/issue16208
family = socket.AF_INET
if flags is None:
flags = socket.AI_PASSIVE
bound_port = None
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
if (platform.system() == 'Darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
# prompt for access (often repeatedly, due to an apparent
# bug in its ability to remember granting access to an
# application). Skip these addresses.
continue
try:
sock = socket.socket(af, socktype, proto)
except socket.error as e:
if errno_from_exception(e) == errno.EAFNOSUPPORT:
continue
raise
set_close_exec(sock.fileno())
if os.name != 'nt':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
# 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
# separate sockets *must* be used to listen for both ipv4
# and ipv6. For consistency, always disable ipv4 on our
# ipv6 sockets and use a separate ipv4 socket when needed.
#
# Python 2.x on windows doesn't have IPPROTO_IPV6.
if hasattr(socket, "IPPROTO_IPV6"):
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
# automatic port allocation with port=None
# should bind on the same port on IPv4 and IPv6
host, requested_port = sockaddr[:2]
if requested_port == 0 and bound_port is not None:
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
sock.setblocking(0)
sock.bind(sockaddr)
bound_port = sock.getsockname()[1]
sock.listen(backlog)
sockets.append(sock)
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=128):
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
If any other file with that name exists, an exception will be
raised.
Returns a socket object (not a list of socket objects like
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
try:
st = os.stat(file)
except OSError as err:
if errno_from_exception(err) != errno.ENOENT:
raise
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
sock.listen(backlog)
return sock
def add_accept_handler(sock, callback, io_loop=None):
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
be run (``connection`` is a socket object, and ``address`` is the
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
"""
if io_loop is None:
io_loop = IOLoop.current()
def accept_handler(fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error as e:
# _ERRNO_WOULDBLOCK indicate we have accepted every
# connection that is available.
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
return
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
callback(connection, address)
io_loop.add_handler(sock, accept_handler, IOLoop.READ)
def is_valid_ip(ip):
"""Returns true if the given string is a well-formed IP address.
Supports IPv4 and IPv6.
"""
if not ip or '\x00' in ip:
# getaddrinfo resolves empty strings to localhost, and truncates
# on zero bytes.
return False
try:
res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
socket.SOCK_STREAM,
0, socket.AI_NUMERICHOST)
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
return False
raise
return True
class Resolver(Configurable):
"""Configurable asynchronous DNS resolver interface.
By default, a blocking implementation is used (which simply calls
`socket.getaddrinfo`). An alternative implementation can be
chosen with the `Resolver.configure <.Configurable.configure>`
class method::
Resolver.configure('tornado.netutil.ThreadedResolver')
The implementations of this interface included with Tornado are
* `tornado.netutil.BlockingResolver`
* `tornado.netutil.ThreadedResolver`
* `tornado.netutil.OverrideResolver`
* `tornado.platform.twisted.TwistedResolver`
* `tornado.platform.caresresolver.CaresResolver`
"""
@classmethod
def configurable_base(cls):
return Resolver
@classmethod
def configurable_default(cls):
return BlockingResolver
def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
literal IP address.
Returns a `.Future` whose result is a list of (family,
address) pairs, where address is a tuple suitable to pass to
`socket.connect <socket.socket.connect>` (i.e. a ``(host,
port)`` pair for IPv4; additional fields may be present for
IPv6). If a ``callback`` is passed, it will be run with the
result as an argument when it is complete.
"""
raise NotImplementedError()
def close(self):
"""Closes the `Resolver`, freeing any resources used.
.. versionadded:: 3.1
"""
pass
class ExecutorResolver(Resolver):
"""Resolver implementation using a `concurrent.futures.Executor`.
Use this instead of `ThreadedResolver` when you require additional
control over the executor being used.
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
"""
def initialize(self, io_loop=None, executor=None, close_executor=True):
self.io_loop = io_loop or IOLoop.current()
if executor is not None:
self.executor = executor
self.close_executor = close_executor
else:
self.executor = dummy_executor
self.close_executor = False
def close(self):
if self.close_executor:
self.executor.shutdown()
self.executor = None
@run_on_executor
def resolve(self, host, port, family=socket.AF_UNSPEC):
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
# one here. The socket type used here doesn't seem to actually
# matter (we discard the one we get back in the results),
# so the addresses we return should still be usable with SOCK_DGRAM.
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
results = []
for family, socktype, proto, canonname, address in addrinfo:
results.append((family, address))
return results
class BlockingResolver(ExecutorResolver):
"""Default `Resolver` implementation, using `socket.getaddrinfo`.
The `.IOLoop` will be blocked during the resolution, although the
callback will not be run until the next `.IOLoop` iteration.
"""
def initialize(self, io_loop=None):
super(BlockingResolver, self).initialize(io_loop=io_loop)
class ThreadedResolver(ExecutorResolver):
"""Multithreaded non-blocking `Resolver` implementation.
Requires the `concurrent.futures` package to be installed
(available in the standard library since Python 3.2,
installable with ``pip install futures`` in older versions).
The thread pool size can be configured with::
Resolver.configure('tornado.netutil.ThreadedResolver',
num_threads=10)
.. versionchanged:: 3.1
All ``ThreadedResolvers`` share a single thread pool, whose
size is set by the first one to be created.
"""
_threadpool = None
_threadpool_pid = None
def initialize(self, io_loop=None, num_threads=10):
threadpool = ThreadedResolver._create_threadpool(num_threads)
super(ThreadedResolver, self).initialize(
io_loop=io_loop, executor=threadpool, close_executor=False)
@classmethod
def _create_threadpool(cls, num_threads):
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
# was when we created the pool then delete it.
cls._threadpool = None
if cls._threadpool is None:
from concurrent.futures import ThreadPoolExecutor
cls._threadpool = ThreadPoolExecutor(num_threads)
cls._threadpool_pid = pid
return cls._threadpool
class OverrideResolver(Resolver):
"""Wraps a resolver with a mapping of overrides.
This can be used to make local DNS changes (e.g. for testing)
without modifying system-wide settings.
The mapping can contain either host strings or host-port pairs.
"""
def initialize(self, resolver, mapping):
self.resolver = resolver
self.mapping = mapping
def close(self):
self.resolver.close()
def resolve(self, host, port, *args, **kwargs):
if (host, port) in self.mapping:
host, port = self.mapping[(host, port)]
elif host in self.mapping:
host = self.mapping[host]
return self.resolver.resolve(host, port, *args, **kwargs)
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
'cert_reqs', 'ca_certs', 'ciphers'])
def ssl_options_to_context(ssl_options):
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
to use features like SNI or NPN.
"""
if isinstance(ssl_options, dict):
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
if (not hasattr(ssl, 'SSLContext') or
isinstance(ssl_options, ssl.SSLContext)):
return ssl_options
context = ssl.SSLContext(
ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
if 'certfile' in ssl_options:
context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
if 'cert_reqs' in ssl_options:
context.verify_mode = ssl_options['cert_reqs']
if 'ca_certs' in ssl_options:
context.load_verify_locations(ssl_options['ca_certs'])
if 'ciphers' in ssl_options:
context.set_ciphers(ssl_options['ciphers'])
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant wasn't added until python 3.3.
context.options |= ssl.OP_NO_COMPRESSION
return context
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either a dictionary (as accepted by
`ssl_options_to_context`) or an `ssl.SSLContext` object.
Additional keyword arguments are passed to ``wrap_socket``
(either the `~ssl.SSLContext` method or the `ssl` module function
as appropriate).
"""
context = ssl_options_to_context(ssl_options)
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
if server_hostname is not None and getattr(ssl, 'HAS_SNI'):
# Python doesn't have server-side SNI support so we can't
# really unittest this, but it can be manually tested with
# python3.2 -m tornado.httpclient https://sni.velox.ch
return context.wrap_socket(socket, server_hostname=server_hostname,
**kwargs)
else:
return context.wrap_socket(socket, **kwargs)
else:
return ssl.wrap_socket(socket, **dict(context, **kwargs))

View file

@ -0,0 +1,553 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A command line parsing module that lets modules define their own options.
Each module defines its own options which are added to the global
option namespace, e.g.::
from tornado.options import define, options
define("mysql_host", default="127.0.0.1:3306", help="Main user DB")
define("memcache_hosts", default="127.0.0.1:11011", multiple=True,
help="Main user memcache servers")
def connect():
db = database.Connection(options.mysql_host)
...
The ``main()`` method of your application does not need to be aware of all of
the options used throughout your program; they are all automatically loaded
when the modules are loaded. However, all modules that define options
must have been imported before the command line is parsed.
Your ``main()`` method can parse the command line or parse a config file with
either::
tornado.options.parse_command_line()
# or
tornado.options.parse_config_file("/etc/server.conf")
Command line formats are what you would expect (``--myoption=myvalue``).
Config files are just Python files. Global names become options, e.g.::
myoption = "myvalue"
myotheroption = "myothervalue"
We support `datetimes <datetime.datetime>`, `timedeltas
<datetime.timedelta>`, ints, and floats (just pass a ``type`` kwarg to
`define`). We also accept multi-value options. See the documentation for
`define()` below.
`tornado.options.options` is a singleton instance of `OptionParser`, and
the top-level functions in this module (`define`, `parse_command_line`, etc)
simply call methods on it. You may create additional `OptionParser`
instances to define isolated sets of options, such as for subcommands.
.. note::
By default, several options are defined that will configure the
standard `logging` module when `parse_command_line` or `parse_config_file`
are called. If you want Tornado to leave the logging configuration
alone so you can manage it yourself, either pass ``--logging=none``
on the command line or do the following to disable it in code::
from tornado.options import options, parse_command_line
options.logging = None
parse_command_line()
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import numbers
import re
import sys
import os
import textwrap
from tornado.escape import _unicode
from tornado.log import define_logging_options
from tornado import stack_context
from tornado.util import basestring_type, exec_in
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
class OptionParser(object):
"""A collection of options, a dictionary with object-like access.
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
def __init__(self):
# we have to use self.__dict__ because we override setattr.
self.__dict__['_options'] = {}
self.__dict__['_parse_callbacks'] = []
self.define("help", type=bool, help="show this help information",
callback=self._help_callback)
def __getattr__(self, name):
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name, value):
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def __iter__(self):
return iter(self._options)
def __getitem__(self, item):
return self._options[item].value()
def items(self):
"""A sequence of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(name, opt.value()) for name, opt in self._options.items()]
def groups(self):
"""The set of option-groups created by ``define``.
.. versionadded:: 3.1
"""
return set(opt.group_name for opt in self._options.values())
def group_dict(self, group):
"""The names and values of options in a group.
Useful for copying options into Application settings::
from tornado.options import define, parse_command_line, options
define('template_path', group='application')
define('static_path', group='application')
parse_command_line()
application = Application(
handlers, **options.group_dict('application'))
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items()
if not group or group == opt.group_name)
def as_dict(self):
"""The names and values of all options.
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items())
def define(self, name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines a new command line option.
If ``type`` is given (one of str, float, int, datetime, or timedelta)
or can be inferred from the ``default``, we parse the command line
arguments based on the given type. If ``multiple`` is True, we accept
comma-separated values, and the option value is always a list.
For multi-value integers, we also accept the syntax ``x:y``, which
turns into ``range(x, y)`` - very useful for long integer ranges.
``help`` and ``metavar`` are used to construct the
automatically generated command line help string. The help
message is formatted like::
--name=METAVAR help string
``group`` is used to group the defined options in logical
groups. By default, command line options are grouped by the
file in which they are defined.
Command line option names must be unique globally. They can be parsed
from the command line with `parse_command_line` or parsed from a
config file with `parse_config_file`.
If a ``callback`` is given, it will be run with the new value whenever
the option is changed. This can be used to combine command-line
and file-based options::
define("config", type=str, help="path to config file",
callback=lambda path: parse_config_file(path, final=False))
With this definition, options in the file specified by ``--config`` will
override options set earlier on the command line, but can be overridden
by later flags.
"""
if name in self._options:
raise Error("Option %r already defined in %s" %
(name, self._options[name].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group
else:
group_name = file_name
self._options[name] = _Option(name, file_name=file_name,
default=default, type=type, help=help,
metavar=metavar, multiple=multiple,
group_name=group_name,
callback=callback)
def parse_command_line(self, args=None, final=True):
"""Parses all options given on the command line (defaults to
`sys.argv`).
Note that ``args[0]`` is ignored since it is the program name
in `sys.argv`.
We return a list of all arguments that are not parsed as options.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
if args is None:
args = sys.argv
remaining = []
for i in range(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i + 1:]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if not name in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error('Option %r requires a value' % name)
option.parse(value)
if final:
self.run_parse_callbacks()
return remaining
def parse_config_file(self, path, final=True):
"""Parses and loads the Python config file at the given path.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
config = {}
with open(path) as f:
exec_in(f.read(), config, config)
for name in config:
if name in self._options:
self._options[name].set(config[name])
if final:
self.run_parse_callbacks()
def print_help(self, file=None):
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
print("\nOptions:\n", file=file)
by_group = {}
for option in self._options.values():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename:
print("\n%s options:\n" % os.path.normpath(filename), file=file)
o.sort(key=lambda option: option.name)
for option in o:
prefix = option.name
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
if option.default is not None and option.default != '':
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
lines.insert(0, '')
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
print("%-34s %s" % (' ', line), file=file)
print(file=file)
def _help_callback(self, value):
if value:
self.print_help()
sys.exit(0)
def add_parse_callback(self, callback):
"""Adds a parse callback, to be invoked when option parsing is done."""
self._parse_callbacks.append(stack_context.wrap(callback))
def run_parse_callbacks(self):
for callback in self._parse_callbacks:
callback()
def mockable(self):
"""Returns a wrapper around self that is compatible with
`mock.patch <unittest.mock.patch>`.
The `mock.patch <unittest.mock.patch>` function (included in
the standard library `unittest.mock` package since Python 3.3,
or in the third-party ``mock`` package for older versions of
Python) is incompatible with objects like ``options`` that
override ``__getattr__`` and ``__setattr__``. This function
returns an object that can be used with `mock.patch.object
<unittest.mock.patch.object>` to modify option values::
with mock.patch.object(options.mockable(), 'name', value):
assert options.name == value
"""
return _Mockable(self)
class _Mockable(object):
"""`mock.patch` compatible wrapper for `OptionParser`.
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
the attribute it set instead of setting a new one (assuming that
the object does not catpure ``__setattr__``, so the patch
created a new attribute in ``__dict__``).
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
def __init__(self, options):
# Modify __dict__ directly to bypass __setattr__
self.__dict__['_options'] = options
self.__dict__['_originals'] = {}
def __getattr__(self, name):
return getattr(self._options, name)
def __setattr__(self, name, value):
assert name not in self._originals, "don't reuse mockable objects"
self._originals[name] = getattr(self._options, name)
setattr(self._options, name, value)
def __delattr__(self, name):
setattr(self._options, name, self._originals.pop(name))
class _Option(object):
UNSET = object()
def __init__(self, name, default=None, type=basestring_type, help=None,
metavar=None, multiple=False, file_name=None, group_name=None,
callback=None):
if default is None and multiple:
default = []
self.name = name
self.type = type
self.help = help
self.metavar = metavar
self.multiple = multiple
self.file_name = file_name
self.group_name = group_name
self.callback = callback
self.default = default
self._value = _Option.UNSET
def value(self):
return self.default if self._value is _Option.UNSET else self._value
def parse(self, value):
_parse = {
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
}.get(self.type, self.type)
if self.multiple:
self._value = []
for part in value.split(","):
if issubclass(self.type, numbers.Integral):
# allow ranges of the form X:Y (inclusive at both ends)
lo, _, hi = part.partition(":")
lo = _parse(lo)
hi = _parse(hi) if hi else lo
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
else:
self._value = _parse(value)
if self.callback is not None:
self.callback(self._value)
return self.value()
def set(self, value):
if self.multiple:
if not isinstance(value, list):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
for item in value:
if item is not None and not isinstance(item, self.type):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
else:
if value is not None and not isinstance(value, self.type):
raise Error("Option %r is required to be a %s (%s given)" %
(self.name, self.type.__name__, type(value)))
self._value = value
if self.callback is not None:
self.callback(self._value)
# Supported date/time formats in our options
_DATETIME_FORMATS = [
"%a %b %d %H:%M:%S %Y",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M",
"%Y%m%d %H:%M:%S",
"%Y%m%d %H:%M",
"%Y-%m-%d",
"%Y%m%d",
"%H:%M:%S",
"%H:%M",
]
def _parse_datetime(self, value):
for format in self._DATETIME_FORMATS:
try:
return datetime.datetime.strptime(value, format)
except ValueError:
pass
raise Error('Unrecognized date/time format: %r' % value)
_TIMEDELTA_ABBREVS = [
('hours', ['h']),
('minutes', ['m', 'min']),
('seconds', ['s', 'sec']),
('milliseconds', ['ms']),
('microseconds', ['us']),
('days', ['d']),
('weeks', ['w']),
]
_TIMEDELTA_ABBREV_DICT = dict(
(abbrev, full) for full, abbrevs in _TIMEDELTA_ABBREVS
for abbrev in abbrevs)
_FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
_TIMEDELTA_PATTERN = re.compile(
r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
def _parse_timedelta(self, value):
try:
sum = datetime.timedelta()
start = 0
while start < len(value):
m = self._TIMEDELTA_PATTERN.match(value, start)
if not m:
raise Exception()
num = float(m.group(1))
units = m.group(2) or 'seconds'
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
sum += datetime.timedelta(**{units: num})
start = m.end()
return sum
except Exception:
raise
def _parse_bool(self, value):
return value.lower() not in ("false", "0", "f")
def _parse_string(self, value):
return _unicode(value)
options = OptionParser()
"""Global options object.
All defined options are available as attributes on this object.
"""
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
return options.define(name, default=default, type=type, help=help,
metavar=metavar, multiple=multiple, group=group,
callback=callback)
def parse_command_line(args=None, final=True):
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
"""
return options.parse_command_line(args, final=final)
def parse_config_file(path, final=True):
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
"""
return options.parse_config_file(path, final=final)
def print_help(file=None):
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
"""
return options.print_help(file)
def add_parse_callback(callback):
"""Adds a parse callback, to be invoked when option parsing is done.
See `OptionParser.add_parse_callback`
"""
options.add_parse_callback(callback)
# Default options
define_logging_options(options)

View file

@ -0,0 +1,142 @@
"""Bridges between the `asyncio` module and Tornado IOLoop.
This is a work in progress and interfaces are subject to change.
To test:
python3.4 -m tornado.test.runtests --ioloop=tornado.platform.asyncio.AsyncIOLoop
python3.4 -m tornado.test.runtests --ioloop=tornado.platform.asyncio.AsyncIOMainLoop
(the tests log a few warnings with AsyncIOMainLoop because they leave some
unfinished callbacks on the event loop that fail when it resumes)
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
from tornado.ioloop import IOLoop
from tornado import stack_context
from tornado.util import timedelta_to_seconds
try:
# Import the real asyncio module for py33+ first. Older versions of the
# trollius backport also use this name.
import asyncio
except ImportError as e:
# Asyncio itself isn't available; see if trollius is (backport to py26+).
try:
import trollius as asyncio
except ImportError:
# Re-raise the original asyncio error, not the trollius one.
raise e
class BaseAsyncIOLoop(IOLoop):
def initialize(self, asyncio_loop, close_loop=False):
self.asyncio_loop = asyncio_loop
self.close_loop = close_loop
self.asyncio_loop.call_soon(self.make_current)
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {}
# Set of fds listening for reads/writes
self.readers = set()
self.writers = set()
self.closing = False
def close(self, all_fds=False):
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
self.remove_handler(fd)
if all_fds:
self.close_fd(fileobj)
if self.close_loop:
self.asyncio_loop.close()
def add_handler(self, fd, handler, events):
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %s added twice" % fd)
self.handlers[fd] = (fileobj, stack_context.wrap(handler))
if events & IOLoop.READ:
self.asyncio_loop.add_reader(
fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
if events & IOLoop.WRITE:
self.asyncio_loop.add_writer(
fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
self.asyncio_loop.add_reader(
fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
else:
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if events & IOLoop.WRITE:
if fd not in self.writers:
self.asyncio_loop.add_writer(
fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
else:
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.handlers:
return
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
del self.handlers[fd]
def _handle_events(self, fd, events):
fileobj, handler_func = self.handlers[fd]
handler_func(fileobj, events)
def start(self):
self._setup_logging()
self.asyncio_loop.run_forever()
def stop(self):
self.asyncio_loop.stop()
def call_at(self, when, callback, *args, **kwargs):
# asyncio.call_at supports *args but not **kwargs, so bind them here.
# We do not synchronize self.time and asyncio_loop.time, so
# convert from absolute to relative.
return self.asyncio_loop.call_later(
max(0, when - self.time()), self._run_callback,
functools.partial(stack_context.wrap(callback), *args, **kwargs))
def remove_timeout(self, timeout):
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
if self.closing:
raise RuntimeError("IOLoop is closing")
self.asyncio_loop.call_soon_threadsafe(
self._run_callback,
functools.partial(stack_context.wrap(callback), *args, **kwargs))
add_callback_from_signal = add_callback
class AsyncIOMainLoop(BaseAsyncIOLoop):
def initialize(self):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
close_loop=False)
class AsyncIOLoop(BaseAsyncIOLoop):
def initialize(self):
super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(),
close_loop=True)

View file

@ -0,0 +1,49 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of platform-specific functionality.
For each function or class described in `tornado.platform.interface`,
the appropriate platform-specific implementation exists in this module.
Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
if os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
elif 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
def set_close_exec(fd):
pass
else:
from tornado.platform.posix import set_close_exec, Waker
try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None

View file

@ -0,0 +1,76 @@
from __future__ import absolute_import, division, print_function, with_statement
import pycares
import socket
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver, is_valid_ip
class CaresResolver(Resolver):
"""Name resolver based on the c-ares library.
This is a non-blocking and non-threaded resolver. It may not produce
the same results as the system resolver, but can be used for non-blocking
resolution when threads cannot be used.
c-ares fails to resolve some names when ``family`` is ``AF_UNSPEC``,
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
self.fds = {}
def _sock_state_cb(self, fd, readable, writable):
state = ((IOLoop.READ if readable else 0) |
(IOLoop.WRITE if writable else 0))
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
elif fd in self.fds:
self.io_loop.update_handler(fd, state)
self.fds[fd] = state
else:
self.io_loop.add_handler(fd, self._handle_events, state)
self.fds[fd] = state
def _handle_events(self, fd, events):
read_fd = pycares.ARES_SOCKET_BAD
write_fd = pycares.ARES_SOCKET_BAD
if events & IOLoop.READ:
read_fd = fd
if events & IOLoop.WRITE:
write_fd = fd
self.channel.process_fd(read_fd, write_fd)
@gen.coroutine
def resolve(self, host, port, family=0):
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
self.channel.gethostbyname(host, family, (yield gen.Callback(1)))
callback_args = yield gen.Wait(1)
assert isinstance(callback_args, gen.Arguments)
assert not callback_args.kwargs
result, error = callback_args.args
if error:
raise Exception('C-Ares returned error %s: %s while resolving %s' %
(error, pycares.errno.strerror(error), host))
addresses = result.addresses
addrinfo = []
for address in addresses:
if '.' in address:
address_family = socket.AF_INET
elif ':' in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
raise Exception('Requested socket family %d but got %d' %
(family, address_family))
addrinfo.append((address_family, (address, port)))
raise gen.Return(addrinfo)

View file

@ -0,0 +1,92 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import socket
from tornado.platform import interface
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe.
For use on platforms that don't have os.pipe() (or where pipes cannot
be passed to select()), but do have sockets. This includes Windows
and Jython.
"""
def __init__(self):
# Based on Zope select_trigger.py:
# https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
a.listen(1)
connect_address = a.getsockname() # assigned (host, port) pair
try:
self.writer.connect(connect_address)
break # success
except socket.error as detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
detail[0] != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.send(b"x")
except (IOError, socket.error):
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result:
break
except (IOError, socket.error):
pass
def close(self):
self.reader.close()
self.writer.close()

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""EPoll-based IOLoop implementation for Linux systems."""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import PollIOLoop
class EPollIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs)

View file

@ -0,0 +1,63 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Interfaces for platform-specific functionality.
This module exists primarily for documentation purposes and as base classes
for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, print_function, with_statement
def set_close_exec(fd):
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()
class Waker(object):
"""A socket-like object that can wake another thread from ``select()``.
The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
thread wants to wake up the loop, it calls `wake`. Once it has woken
up, it will call `consume` to do any necessary per-wake cleanup. When
the ``IOLoop`` is closed, it closes its waker too.
"""
def fileno(self):
"""Returns the read file descriptor for this waker.
Must be suitable for use with ``select()`` or equivalent on the
local platform.
"""
raise NotImplementedError()
def write_fileno(self):
"""Returns the write file descriptor for this waker."""
raise NotImplementedError()
def wake(self):
"""Triggers activity on the waker's file descriptor."""
raise NotImplementedError()
def consume(self):
"""Called after the listen has woken up to do any necessary cleanup."""
raise NotImplementedError()
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()

View file

@ -0,0 +1,92 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import IOLoop, PollIOLoop
assert hasattr(select, 'kqueue'), 'kqueue not supported'
class _KQueue(object):
"""A kqueue-based event loop for BSD/Mac systems."""
def __init__(self):
self._kqueue = select.kqueue()
self._active = {}
def fileno(self):
return self._kqueue.fileno()
def close(self):
self._kqueue.close()
def register(self, fd, events):
if fd in self._active:
raise IOError("fd %s already registered" % fd)
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
events = self._active.pop(fd)
self._control(fd, events, select.KQ_EV_DELETE)
def _control(self, fd, events, flags):
kevents = []
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
if events & IOLoop.READ or not kevents:
# Always read when there is not a write
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL
# on Mac OS X (10.6) when there is more than one event in the list.
for kevent in kevents:
self._kqueue.control([kevent], 0)
def poll(self, timeout):
kevents = self._kqueue.control(None, 1000, timeout)
events = {}
for kevent in kevents:
fd = kevent.ident
if kevent.filter == select.KQ_FILTER_READ:
events[fd] = events.get(fd, 0) | IOLoop.READ
if kevent.filter == select.KQ_FILTER_WRITE:
if kevent.flags & select.KQ_EV_EOF:
# If an asynchronous connection is refused, kqueue
# returns a write event with the EOF flag set.
# Turn this into an error for consistency with the
# other IOLoop implementations.
# Note that for read events, EOF may be returned before
# all data has been consumed from the socket buffer,
# so we only check for EOF on write events.
events[fd] = IOLoop.ERROR
else:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
if kevent.flags & select.KQ_EV_ERROR:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class KQueueIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs)

View file

@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import fcntl
import os
from tornado.platform import interface
def set_close_exec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
class Waker(interface.Waker):
def __init__(self):
r, w = os.pipe()
_set_nonblocking(r)
_set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self.reader = os.fdopen(r, "rb", 0)
self.writer = os.fdopen(w, "wb", 0)
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.write(b"x")
except IOError:
pass
def consume(self):
try:
while True:
result = self.reader.read()
if not result:
break
except IOError:
pass
def close(self):
self.reader.close()
self.writer.close()

View file

@ -0,0 +1,76 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Select-based IOLoop implementation.
Used as a fallback for systems that don't support epoll or kqueue.
"""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import IOLoop, PollIOLoop
class _Select(object):
"""A simple, select()-based IOLoop implementation for non-Linux systems"""
def __init__(self):
self.read_fds = set()
self.write_fds = set()
self.error_fds = set()
self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
def close(self):
pass
def register(self, fd, events):
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
raise IOError("fd %s already registered" % fd)
if events & IOLoop.READ:
self.read_fds.add(fd)
if events & IOLoop.WRITE:
self.write_fds.add(fd)
if events & IOLoop.ERROR:
self.error_fds.add(fd)
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
self.read_fds.discard(fd)
self.write_fds.discard(fd)
self.error_fds.discard(fd)
def poll(self, timeout):
readable, writeable, errors = select.select(
self.read_fds, self.write_fds, self.error_fds, timeout)
events = {}
for fd in readable:
events[fd] = events.get(fd, 0) | IOLoop.READ
for fd in writeable:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
for fd in errors:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class SelectIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs)

View file

@ -0,0 +1,556 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Note: This module's docs are not currently extracted automatically,
# so changes must be made manually to twisted.rst
# TODO: refactor doc build process to use an appropriate virtualenv
"""Bridges between the Twisted reactor and Tornado IOLoop.
This module lets you run applications and libraries written for
Twisted in a Tornado application. It can be used in two modes,
depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
Twisted on Tornado
------------------
`TornadoReactor` implements the Twisted reactor interface on top of
the Tornado IOLoop. To use it, simply call `install` at the beginning
of the application::
import tornado.platform.twisted
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call `IOLoop.instance().start()`
instead of `reactor.run()`.
It is also possible to create a non-global reactor by calling
`tornado.platform.twisted.TornadoReactor(io_loop)`. However, if
the `IOLoop` and reactor are to be short-lived (such as those used in
unit tests), additional cleanup may be required. Specifically, it is
recommended to call::
reactor.fireSystemEvent('shutdown')
reactor.disconnectAll()
before closing the `IOLoop`.
Tornado on Twisted
------------------
`TwistedIOLoop` implements the Tornado IOLoop interface on top of the Twisted
reactor. Recommended usage::
from tornado.platform.twisted import TwistedIOLoop
from twisted.internet import reactor
TwistedIOLoop().install()
# Set up your tornado application as usual using `IOLoop.instance`
reactor.run()
`TwistedIOLoop` always uses the global Twisted reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
import numbers
import socket
import twisted.internet.abstract
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
from twisted.python import failure, log
from twisted.internet import error
import twisted.names.cache
import twisted.names.client
import twisted.names.hosts
import twisted.names.resolve
from zope.interface import implementer
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
from tornado.log import app_log
from tornado.netutil import Resolver
from tornado.stack_context import NullContext, wrap
from tornado.ioloop import IOLoop
from tornado.util import timedelta_to_seconds
@implementer(IDelayedCall)
class TornadoDelayedCall(object):
"""DelayedCall object for Tornado."""
def __init__(self, reactor, seconds, f, *args, **kw):
self._reactor = reactor
self._func = functools.partial(f, *args, **kw)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
self._active = True
def _called(self):
self._active = False
self._reactor._removeDelayedCall(self)
try:
self._func()
except:
app_log.error("_called caught exception", exc_info=True)
def getTime(self):
return self._time
def cancel(self):
self._active = False
self._reactor._io_loop.remove_timeout(self._timeout)
self._reactor._removeDelayedCall(self)
def delay(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time += seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def reset(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def active(self):
return self._active
@implementer(IReactorTime, IReactorFDSet)
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
Since it is intented to be used in applications where the top-level
event loop is ``io_loop.start()`` rather than ``reactor.run()``,
it is implemented a little differently than other Twisted reactors.
We override `mainLoop` instead of `doIteration` and must implement
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
"""
def __init__(self, io_loop=None):
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
self._io_loop = io_loop
self._readers = {} # map of reader objects to fd
self._writers = {} # map of writer objects to fd
self._fds = {} # a map of fd to a (reader, writer) tuple
self._delayedCalls = {}
PosixReactorBase.__init__(self)
self.addSystemEventTrigger('during', 'shutdown', self.crash)
# IOLoop.start() bypasses some of the reactor initialization.
# Fire off the necessary events if they weren't already triggered
# by reactor.run().
def start_if_necessary():
if not self._started:
self.fireSystemEvent('startup')
self._io_loop.add_callback(start_if_necessary)
# IReactorTime
def seconds(self):
return self._io_loop.time()
def callLater(self, seconds, f, *args, **kw):
dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
self._delayedCalls[dc] = True
return dc
def getDelayedCalls(self):
return [x for x in self._delayedCalls if x._active]
def _removeDelayedCall(self, dc):
if dc in self._delayedCalls:
del self._delayedCalls[dc]
# IReactorThreads
def callFromThread(self, f, *args, **kw):
"""See `twisted.internet.interfaces.IReactorThreads.callFromThread`"""
assert callable(f), "%s is not callable" % f
with NullContext():
# This NullContext is mainly for an edge case when running
# TwistedIOLoop on top of a TornadoReactor.
# TwistedIOLoop.add_callback uses reactor.callFromThread and
# should not pick up additional StackContexts along the way.
self._io_loop.add_callback(f, *args, **kw)
# We don't need the waker code from the super class, Tornado uses
# its own waker.
def installWaker(self):
pass
def wakeUp(self):
pass
# IReactorFDSet
def _invoke_callback(self, fd, events):
if fd not in self._fds:
return
(reader, writer) = self._fds[fd]
if reader:
err = None
if reader.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.READ:
err = log.callWithLogger(reader, reader.doRead)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeReader(reader)
reader.readConnectionLost(failure.Failure(err))
if writer:
err = None
if writer.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.WRITE:
err = log.callWithLogger(writer, writer.doWrite)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeWriter(writer)
writer.writeConnectionLost(failure.Failure(err))
def addReader(self, reader):
"""Add a FileDescriptor for notification of data available to read."""
if reader in self._readers:
# Don't add the reader if it's already there
return
fd = reader.fileno()
self._readers[reader] = fd
if fd in self._fds:
(_, writer) = self._fds[fd]
self._fds[fd] = (reader, writer)
if writer:
# We already registered this fd for write events,
# update it for read events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (reader, None)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.READ)
def addWriter(self, writer):
"""Add a FileDescriptor for notification of data available to write."""
if writer in self._writers:
return
fd = writer.fileno()
self._writers[writer] = fd
if fd in self._fds:
(reader, _) = self._fds[fd]
self._fds[fd] = (reader, writer)
if reader:
# We already registered this fd for read events,
# update it for write events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (None, writer)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.WRITE)
def removeReader(self, reader):
"""Remove a Selectable for notification of data available to read."""
if reader in self._readers:
fd = self._readers.pop(reader)
(_, writer) = self._fds[fd]
if writer:
# We have a writer so we need to update the IOLoop for
# write events only.
self._fds[fd] = (None, writer)
self._io_loop.update_handler(fd, IOLoop.WRITE)
else:
# Since we have no writer registered, we remove the
# entry from _fds and unregister the handler from the
# IOLoop
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeWriter(self, writer):
"""Remove a Selectable for notification of data available to write."""
if writer in self._writers:
fd = self._writers.pop(writer)
(reader, _) = self._fds[fd]
if reader:
# We have a reader so we need to update the IOLoop for
# read events only.
self._fds[fd] = (reader, None)
self._io_loop.update_handler(fd, IOLoop.READ)
else:
# Since we have no reader registered, we remove the
# entry from the _fds and unregister the handler from
# the IOLoop.
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeAll(self):
return self._removeAll(self._readers, self._writers)
def getReaders(self):
return self._readers.keys()
def getWriters(self):
return self._writers.keys()
# The following functions are mainly used in twisted-style test cases;
# it is expected that most users of the TornadoReactor will call
# IOLoop.start() instead of Reactor.run().
def stop(self):
PosixReactorBase.stop(self)
fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown")
self._io_loop.add_callback(fire_shutdown)
def crash(self):
PosixReactorBase.crash(self)
self._io_loop.stop()
def doIteration(self, delay):
raise NotImplementedError("doIteration")
def mainLoop(self):
self._io_loop.start()
class _TestReactor(TornadoReactor):
"""Subclass of TornadoReactor for use in unittests.
This can't go in the test.py file because of import-order dependencies
with the Twisted reactor test builder.
"""
def __init__(self):
# always use a new ioloop
super(_TestReactor, self).__init__(IOLoop())
def listenTCP(self, port, factory, backlog=50, interface=''):
# default to localhost to avoid firewall prompts on the mac
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenTCP(
port, factory, backlog=backlog, interface=interface)
def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenUDP(
port, protocol, interface=interface, maxPacketSize=maxPacketSize)
def install(io_loop=None):
"""Install this package as the default Twisted reactor."""
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
from twisted.internet.main import installReactor
installReactor(reactor)
return reactor
@implementer(IReadDescriptor, IWriteDescriptor)
class _FD(object):
def __init__(self, fd, fileobj, handler):
self.fd = fd
self.fileobj = fileobj
self.handler = handler
self.reading = False
self.writing = False
self.lost = False
def fileno(self):
return self.fd
def doRead(self):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.READ)
def doWrite(self):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE)
def connectionLost(self, reason):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR)
self.lost = True
def logPrefix(self):
return ''
class TwistedIOLoop(tornado.ioloop.IOLoop):
"""IOLoop implementation that runs on Twisted.
Uses the global Twisted reactor by default. To create multiple
`TwistedIOLoops` in the same process, you must pass a unique reactor
when constructing each one.
Not compatible with `tornado.process.Subprocess.set_exit_callback`
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
"""
def initialize(self, reactor=None):
if reactor is None:
import twisted.internet.reactor
reactor = twisted.internet.reactor
self.reactor = reactor
self.fds = {}
self.reactor.callWhenRunning(self.make_current)
def close(self, all_fds=False):
fds = self.fds
self.reactor.removeAll()
for c in self.reactor.getDelayedCalls():
c.cancel()
if all_fds:
for fd in fds.values():
self.close_fd(fd.fileobj)
def add_handler(self, fd, handler, events):
if fd in self.fds:
raise ValueError('fd %s added twice' % fd)
fd, fileobj = self.split_fd(fd)
self.fds[fd] = _FD(fd, fileobj, wrap(handler))
if events & tornado.ioloop.IOLoop.READ:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & tornado.ioloop.IOLoop.READ:
if not self.fds[fd].reading:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
else:
if self.fds[fd].reading:
self.fds[fd].reading = False
self.reactor.removeReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
if not self.fds[fd].writing:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
else:
if self.fds[fd].writing:
self.fds[fd].writing = False
self.reactor.removeWriter(self.fds[fd])
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.fds:
return
self.fds[fd].lost = True
if self.fds[fd].reading:
self.reactor.removeReader(self.fds[fd])
if self.fds[fd].writing:
self.reactor.removeWriter(self.fds[fd])
del self.fds[fd]
def start(self):
self._setup_logging()
self.reactor.run()
def stop(self):
self.reactor.crash()
def add_timeout(self, deadline, callback, *args, **kwargs):
# This method could be simplified (since tornado 4.0) by
# overriding call_at instead of add_timeout, but we leave it
# for now as a test of backwards-compatibility.
if isinstance(deadline, numbers.Real):
delay = max(deadline - self.time(), 0)
elif isinstance(deadline, datetime.timedelta):
delay = timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r")
return self.reactor.callLater(
delay, self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def remove_timeout(self, timeout):
if timeout.active():
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
self.reactor.callFromThread(
self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def add_callback_from_signal(self, callback, *args, **kwargs):
self.add_callback(callback, *args, **kwargs)
class TwistedResolver(Resolver):
"""Twisted-based asynchronous resolver.
This is a non-blocking and non-threaded resolver. It is
recommended only when threads cannot be used, since it has
limitations compared to the standard ``getaddrinfo``-based
`~tornado.netutil.Resolver` and
`~tornado.netutil.ThreadedResolver`. Specifically, it returns at
most one result, and arguments other than ``host`` and ``family``
are ignored. It may fail to resolve when ``family`` is not
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
self.reactor = tornado.platform.twisted.TornadoReactor(io_loop)
host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
reactor=self.reactor)
self.resolver = twisted.names.resolve.ResolverChain(
[host_resolver, cache_resolver, real_resolver])
@gen.coroutine
def resolve(self, host, port, family=0):
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
resolved = host
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(host):
resolved = host
resolved_family = socket.AF_INET6
else:
deferred = self.resolver.getHostByName(utf8(host))
resolved = yield gen.Task(deferred.addBoth)
if isinstance(resolved, failure.Failure):
resolved.raiseException()
elif twisted.internet.abstract.isIPAddress(resolved):
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(resolved):
resolved_family = socket.AF_INET6
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
raise Exception('Requested socket family %d but got %d' %
(family, resolved_family))
result = [
(resolved_family, (resolved, port)),
]
raise gen.Return(result)

View file

@ -0,0 +1,20 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
from __future__ import absolute_import, division, print_function, with_statement
import ctypes
import ctypes.wintypes
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.GetLastError()

View file

@ -0,0 +1,312 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with multiple processes, including both forking
the server into multiple processes and managing subprocesses.
"""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import signal
import subprocess
import sys
import time
from binascii import hexlify
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
from tornado.util import errno_from_exception
try:
import multiprocessing
except ImportError:
# Multiprocessing is not availble on Google App Engine.
multiprocessing = None
try:
long # py2
except NameError:
long = int # py3
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
return 1
try:
return multiprocessing.cpu_count()
except NotImplementedError:
pass
try:
return os.sysconf("SC_NPROCESSORS_CONF")
except ValueError:
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
def _reseed_random():
if 'random' not in sys.modules:
return
import random
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
try:
seed = long(hexlify(os.urandom(16)), 16)
except NotImplementedError:
seed = int(time.time() * 1000) ^ os.getpid()
random.seed(seed)
def _pipe_cloexec():
r, w = os.pipe()
set_close_exec(r)
set_close_exec(w)
return r, w
_task_id = None
def fork_processes(num_processes, max_restarts=100):
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If ``num_processes`` is given and > 0, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``fork_processes``.
In each child process, ``fork_processes`` returns its *task id*, a
number between 0 and ``num_processes``. Processes that exit
abnormally (due to a signal or non-zero exit status) are restarted
with the same id (up to ``max_restarts`` times). In the parent
process, ``fork_processes`` returns None if all child processes
have exited normally, but will otherwise only exit by throwing an
exception.
"""
global _task_id
assert _task_id is None
if num_processes is None or num_processes <= 0:
num_processes = cpu_count()
if ioloop.IOLoop.initialized():
raise RuntimeError("Cannot run in multiple processes: IOLoop instance "
"has already been initialized. You cannot call "
"IOLoop.instance() before calling start_processes()")
gen_log.info("Starting %d processes", num_processes)
children = {}
def start_child(i):
pid = os.fork()
if pid == 0:
# child process
_reseed_random()
global _task_id
_task_id = i
return i
else:
children[pid] = i
return None
for i in range(num_processes):
id = start_child(i)
if id is not None:
return id
num_restarts = 0
while children:
try:
pid, status = os.wait()
except OSError as e:
if errno_from_exception(e) == errno.EINTR:
continue
raise
if pid not in children:
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
id, pid, os.WTERMSIG(status))
elif os.WEXITSTATUS(status) != 0:
gen_log.warning("child %d (pid %d) exited with status %d, restarting",
id, pid, os.WEXITSTATUS(status))
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
num_restarts += 1
if num_restarts > max_restarts:
raise RuntimeError("Too many child restarts, giving up")
new_id = start_child(id)
if new_id is not None:
return new_id
# All child processes exited cleanly, so exit the master process
# instead of just returning to right after the call to
# fork_processes (which will probably just start up another IOLoop
# unless the caller checks the return value).
sys.exit(0)
def task_id():
"""Returns the current task id, if any.
Returns None if this process was not created by `fork_processes`.
"""
global _task_id
return _task_id
class Subprocess(object):
"""Wraps ``subprocess.Popen`` with IOStream support.
The constructor is the same as ``subprocess.Popen`` with the following
additions:
* ``stdin``, ``stdout``, and ``stderr`` may have the value
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
"""
STREAM = object()
_initialized = False
_waiting = {}
def __init__(self, *args, **kwargs):
self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
# All FDs we create should be closed on error; those in to_close
# should be closed in the parent process on success.
pipe_fds = []
to_close = []
if kwargs.get('stdin') is Subprocess.STREAM:
in_r, in_w = _pipe_cloexec()
kwargs['stdin'] = in_r
pipe_fds.extend((in_r, in_w))
to_close.append(in_r)
self.stdin = PipeIOStream(in_w, io_loop=self.io_loop)
if kwargs.get('stdout') is Subprocess.STREAM:
out_r, out_w = _pipe_cloexec()
kwargs['stdout'] = out_w
pipe_fds.extend((out_r, out_w))
to_close.append(out_w)
self.stdout = PipeIOStream(out_r, io_loop=self.io_loop)
if kwargs.get('stderr') is Subprocess.STREAM:
err_r, err_w = _pipe_cloexec()
kwargs['stderr'] = err_w
pipe_fds.extend((err_r, err_w))
to_close.append(err_w)
self.stderr = PipeIOStream(err_r, io_loop=self.io_loop)
try:
self.proc = subprocess.Popen(*args, **kwargs)
except:
for fd in pipe_fds:
os.close(fd)
raise
for fd in to_close:
os.close(fd)
for attr in ['stdin', 'stdout', 'stderr', 'pid']:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
self._exit_callback = None
self.returncode = None
def set_exit_callback(self, callback):
"""Runs ``callback`` when this process exits.
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHILD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
one ``IOLoop`` to run the signal handlers.
In many cases a close callback on the stdout or stderr streams
can be used as an alternative to an exit callback if the
signal handler is causing a problem.
"""
self._exit_callback = stack_context.wrap(callback)
Subprocess.initialize(self.io_loop)
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHILD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
"""
if cls._initialized:
return
if io_loop is None:
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
cls._initialized = True
@classmethod
def uninitialize(cls):
"""Removes the ``SIGCHILD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)
cls._initialized = False
@classmethod
def _cleanup(cls):
for pid in list(cls._waiting.keys()): # make a copy
cls._try_cleanup_process(pid)
@classmethod
def _try_cleanup_process(cls, pid):
try:
ret_pid, status = os.waitpid(pid, os.WNOHANG)
except OSError as e:
if errno_from_exception(e) == errno.ECHILD:
return
if ret_pid == 0:
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
subproc.io_loop.add_callback_from_signal(
subproc._set_returncode, status)
def _set_returncode(self, status):
if os.WIFSIGNALED(status):
self.returncode = -os.WTERMSIG(status)
else:
assert os.WIFEXITED(status)
self.returncode = os.WEXITSTATUS(status)
if self._exit_callback:
callback = self._exit_callback
self._exit_callback = None
callback(self.returncode)

View file

@ -0,0 +1,527 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import is_future
from tornado.escape import utf8, _unicode
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
import base64
import collections
import copy
import functools
import re
import socket
import sys
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urlparse # py2
except ImportError:
import urllib.parse as urlparse # py3
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
try:
import certifi
except ImportError:
certifi = None
def _default_ca_certs():
if certifi is None:
raise Exception("The 'certifi' package is required to use https "
"in simple_httpclient")
return certifi.where()
class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
It does not currently implement all applicable parts of the HTTP
specification, but it does enough to work with major web service APIs.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
used.
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None, max_header_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
force_instance=True may be used to suppress this behavior.
max_clients is the number of concurrent requests that can be
in progress. Note that this arguments are only used when the
client is first created, and will be ignored when an existing
client is reused.
hostname_mapping is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like /etc/hosts is not possible or desirable (e.g. in
unittests).
max_buffer_size is the number of bytes that can be read by IOStream. It
defaults to 100mb.
"""
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
defaults=defaults)
self.max_clients = max_clients
self.queue = collections.deque()
self.active = {}
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
self.resolver = resolver
self.own_resolver = False
else:
self.resolver = Resolver(io_loop=io_loop)
self.own_resolver = True
if hostname_mapping is not None:
self.resolver = OverrideResolver(resolver=self.resolver,
mapping=hostname_mapping)
self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop)
def close(self):
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
self.tcp_client.close()
def fetch_impl(self, request, callback):
key = object()
self.queue.append((key, request, callback))
if not len(self.active) < self.max_clients:
timeout_handle = self.io_loop.add_timeout(
self.io_loop.time() + min(request.connect_timeout,
request.request_timeout),
functools.partial(self._on_timeout, key))
else:
timeout_handle = None
self.waiting[key] = (request, callback, timeout_handle)
self._process_queue()
if self.queue:
gen_log.debug("max_clients limit reached, request queued. "
"%d active, %d queued requests." % (
len(self.active), len(self.queue)))
def _process_queue(self):
with stack_context.NullContext():
while self.queue and len(self.active) < self.max_clients:
key, request, callback = self.queue.popleft()
if key not in self.waiting:
continue
self._remove_timeout(key)
self.active[key] = (request, callback)
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size)
def _release_fetch(self, key):
del self.active[key]
self._process_queue()
def _remove_timeout(self, key):
if key in self.waiting:
request, callback, timeout_handle = self.waiting[key]
if timeout_handle is not None:
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
def _on_timeout(self, key):
request, callback, timeout_handle = self.waiting[key]
self.queue.remove((key, request, callback))
timeout_response = HTTPResponse(
request, 599, error=HTTPError(599, "Timeout"),
request_time=self.io_loop.time() - request.start_time)
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
class _HTTPConnection(httputil.HTTPMessageDelegate):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, tcp_client,
max_header_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
self.request = request
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.code = None
self.headers = None
self.chunks = []
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None
self._sockaddr = None
with stack_context.ExceptionStackContext(self._handle_exception):
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
raise ValueError("Unsupported url scheme: %s" %
self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if request.allow_ipv6 is False:
af = socket.AF_INET
else:
af = socket.AF_UNSPEC
ssl_options = self._get_ssl_options(self.parsed.scheme)
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(self._on_timeout))
self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options,
callback=self._on_connect)
def _get_ssl_options(self, scheme):
if scheme == "https":
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
else:
ssl_options["ca_certs"] = _default_ca_certs()
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
if self.request.client_cert is not None:
ssl_options["certfile"] = self.request.client_cert
# SSL interoperability is tricky. We want to disable
# SSLv2 for security reasons; it wasn't disabled by default
# until openssl 1.0. The best way to do this is to use
# the SSL_OP_NO_SSLv2, but that wasn't exposed to python
# until 3.2. Python 2.7 adds the ciphers argument, which
# can also be used to disable SSLv2. As a last resort
# on python 2.6, we set ssl_version to TLSv1. This is
# more narrow than we'd like since it also breaks
# compatibility with servers configured for SSLv3 only,
# but nearly all servers support both SSLv3 and TLSv1:
# http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
if sys.version_info >= (2, 7):
# In addition to disabling SSLv2, we also exclude certain
# classes of insecure ciphers.
ssl_options["ciphers"] = "DEFAULT:!SSLv2:!EXPORT:!DES"
else:
# This is really only necessary for pre-1.0 versions
# of openssl, but python 2.6 doesn't expose version
# information.
ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
return ssl_options
return None
def _on_timeout(self):
self._timeout = None
if self.final_callback is not None:
raise HTTPError(599, "Timeout")
def _remove_timeout(self):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _on_connect(self, stream):
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout.
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self._on_close)
self._remove_timeout()
if self.final_callback is None:
return
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
stack_context.wrap(self._on_timeout))
if (self.request.method not in self._SUPPORTED_METHODS and
not self.request.allow_nonstandard_methods):
raise KeyError("unknown method %s" % self.request.method)
for key in ('network_interface',
'proxy_host', 'proxy_port',
'proxy_username', 'proxy_password'):
if getattr(self.request, key, None):
raise NotImplementedError('%s not supported' % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
if '@' in self.parsed.netloc:
self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
else:
self.request.headers["Host"] = self.parsed.netloc
username, password = None, None
if self.parsed.username is not None:
username, password = self.parsed.username, self.parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
password = self.request.auth_password or ''
if username is not None:
if self.request.auth_mode not in (None, "basic"):
raise ValueError("unsupported auth_mode %s",
self.request.auth_mode)
auth = utf8(username) + b":" + utf8(password)
self.request.headers["Authorization"] = (b"Basic " +
base64.b64encode(auth))
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
if (self.request.body is None and
self.request.body_producer is None):
raise AssertionError(
'Body must not be empty for "%s" request'
% self.request.method)
else:
if (self.request.body is not None or
self.request.body_producer is not None):
raise AssertionError(
'Body must be empty for "%s" request'
% self.request.method)
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
self.request.headers["Content-Length"] = str(len(
self.request.body))
if (self.request.method == "POST" and
"Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
self.stream.set_nodelay(True)
self.connection = HTTP1Connection(
self.stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
decompress=self.request.decompress_response),
self._sockaddr)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
self._read_response()
else:
self._write_body(True)
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):
def on_body_written(fut):
fut.result()
self.connection.finish()
if start_read:
self._read_response()
self.io_loop.add_future(fut, on_body_written)
return
self.connection.finish()
if start_read:
self._read_response()
def _read_response(self):
# Ensure that any exception raised in read_response ends up in our
# stack context.
self.io_loop.add_future(
self.connection.read_response(self),
lambda f: f.result())
def _release(self):
if self.release_callback is not None:
release_callback = self.release_callback
self.release_callback = None
release_callback()
def _run_callback(self, response):
self._release()
if self.final_callback is not None:
final_callback = self.final_callback
self.final_callback = None
self.io_loop.add_callback(final_callback, response)
def _handle_exception(self, typ, value, tb):
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
value = HTTPError(599, "Stream closed")
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
))
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
# by the connection's Future. Should we cancel the
# connection more gracefully?
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def _on_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
raise self.stream.error
raise HTTPError(599, message)
def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
return
self.headers = headers
self.code = first_line.code
self.reason = first_line.reason
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', self.headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise ValueError("Multiple unequal Content-Lengths: %r" %
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]
content_length = int(self.headers["Content-Length"])
else:
content_length = None
if self.request.header_callback is not None:
# Reassemble the start line.
self.request.header_callback('%s %s %s\r\n' % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
if (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307)):
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urlparse.urljoin(self.request.url,
self.headers["Location"])
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
# Client SHOULD make a GET request after a 303.
# According to the spec, 302 should be followed by the same
# method as the original request, but in practice browsers
# treat 302 the same as 303, and many servers use 302 for
# compatibility with pre-HTTP/1.1 user agents which don't
# understand the 303 status.
if self.code in (302, 303):
new_request.method = "GET"
new_request.body = None
for h in ["Content-Length", "Content-Type",
"Content-Encoding", "Transfer-Encoding"]:
try:
del self.request.headers[h]
except KeyError:
pass
new_request.original_request = original_request
final_callback = self.final_callback
self.final_callback = None
self._release()
self.client.fetch(new_request, final_callback)
self._on_end_request()
return
if self.request.streaming_callback:
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
self.code, reason=getattr(self, 'reason', None),
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
buffer=buffer,
effective_url=self.request.url)
self._run_callback(response)
self._on_end_request()
def _on_end_request(self):
self.stream.close()
def data_received(self, chunk):
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
if __name__ == "__main__":
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
main()

View file

@ -0,0 +1,388 @@
#!/usr/bin/env python
#
# Copyright 2010 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""`StackContext` allows applications to maintain threadlocal-like state
that follows execution as it moves to other execution contexts.
The motivating examples are to eliminate the need for explicit
``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
allow some additional context to be kept for logging.
This is slightly magic, but it's an extension of the idea that an
exception handler is a kind of stack-local state and when that stack
is suspended and resumed in a new context that state needs to be
preserved. `StackContext` shifts the burden of restoring that state
from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
in ``async_callback``) to the mechanisms that transfer control from
one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
thread pools, etc).
Example usage::
@contextlib.contextmanager
def die_on_error():
try:
yield
except Exception:
logging.error("exception in asynchronous operation",exc_info=True)
sys.exit(1)
with StackContext(die_on_error):
# Any exception thrown here *or in callback and its desendents*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouln't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a
stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
(for example, if you're writing a thread pool), use
`.stack_context.wrap()` before any asynchronous operations to capture the
stack context from where the operation was started.
* If you're writing an asynchronous library that has some shared
resources (such as a connection pool), create those shared resources
within a ``with stack_context.NullContext():`` block. This will prevent
``StackContexts`` from leaking from one request to another.
* If you want to write something like an exception handler that will
persist across asynchronous calls, create a new `StackContext` (or
`ExceptionStackContext`), and make your asynchronous calls in a ``with``
block that references your `StackContext`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
import threading
from tornado.util import raise_exc_info
class StackContextInconsistentError(Exception):
pass
class _State(threading.local):
def __init__(self):
self.contexts = (tuple(), None)
_state = _State()
class StackContext(object):
"""Establishes the given context as a StackContext that will be transferred.
Note that the parameter is a callable that returns a context
manager, not the context itself. That is, where for a
non-transferable context manager you would say::
with my_context():
StackContext takes the function itself rather than its result::
with StackContext(my_context):
The result of ``with StackContext() as cb:`` is a deactivation
callback. Run this callback when the StackContext is no longer
needed to ensure that it is not propagated any further (note that
deactivating a context does not affect any instances of that
context that are currently pending). This is an advanced feature
and not necessary in most applications.
"""
def __init__(self, context_factory):
self.context_factory = context_factory
self.contexts = []
self.active = True
def _deactivate(self):
self.active = False
# StackContext protocol
def enter(self):
context = self.context_factory()
self.contexts.append(context)
context.__enter__()
def exit(self, type, value, traceback):
context = self.contexts.pop()
context.__exit__(type, value, traceback)
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0] + (self,), self)
_state.contexts = self.new_contexts
try:
self.enter()
except:
_state.contexts = self.old_contexts
raise
return self._deactivate
def __exit__(self, type, value, traceback):
try:
self.exit(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
# Generator coroutines and with-statements with non-local
# effects interact badly. Check here for signs of
# the stack getting out of sync.
# Note that this check comes after restoring _state.context
# so that if it fails things are left in a (relatively)
# consistent state.
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class ExceptionStackContext(object):
"""Specialization of StackContext for exception handling.
The supplied ``exception_handler`` function will be called in the
event of an uncaught exception in this context. The semantics are
similar to a try/finally clause, and intended use cases are to log
an error, close a socket, or similar cleanup actions. The
``exc_info`` triple ``(type, value, traceback)`` will be passed to the
exception_handler function.
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
"""
def __init__(self, exception_handler):
self.exception_handler = exception_handler
self.active = True
def _deactivate(self):
self.active = False
def exit(self, type, value, traceback):
if type is not None:
return self.exception_handler(type, value, traceback)
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0], self)
_state.contexts = self.new_contexts
return self._deactivate
def __exit__(self, type, value, traceback):
try:
if type is not None:
return self.exception_handler(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class NullContext(object):
"""Resets the `StackContext`.
Useful when creating a shared resource on demand (e.g. an
`.AsyncHTTPClient`) where the stack that caused the creating is
not relevant to future operations.
"""
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (tuple(), None)
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
def _remove_deactivated(contexts):
"""Remove deactivated handlers from the chain"""
# Clean ctx handlers
stack_contexts = tuple([h for h in contexts[0] if h.active])
# Find new head
head = contexts[1]
while head is not None and not head.active:
head = head.old_contexts[1]
# Process chain
ctx = head
while ctx is not None:
parent = ctx.old_contexts[1]
while parent is not None:
if parent.active:
break
ctx.old_contexts = parent.old_contexts
parent = parent.old_contexts[1]
ctx = parent
return (stack_contexts, head)
def wrap(fn):
"""Returns a callable object that will restore the current `StackContext`
when executed.
Use this whenever saving a callback to be executed later in a
different execution context (either in a different thread or
asynchronously in the same thread).
"""
# Check if function is already wrapped
if fn is None or hasattr(fn, '_wrapped'):
return fn
# Capture current stack head
# TODO: Any other better way to store contexts and update them in wrapped function?
cap_contexts = [_state.contexts]
if not cap_contexts[0][0] and not cap_contexts[0][1]:
# Fast path when there are no active contexts.
def null_wrapper(*args, **kwargs):
try:
current_state = _state.contexts
_state.contexts = cap_contexts[0]
return fn(*args, **kwargs)
finally:
_state.contexts = current_state
null_wrapper._wrapped = True
return null_wrapper
def wrapped(*args, **kwargs):
ret = None
try:
# Capture old state
current_state = _state.contexts
# Remove deactivated items
cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
# Force new state
_state.contexts = contexts
# Current exception
exc = (None, None, None)
top = None
# Apply stack contexts
last_ctx = 0
stack = contexts[0]
# Apply state
for n in stack:
try:
n.enter()
last_ctx += 1
except:
# Exception happened. Record exception info and store top-most handler
exc = sys.exc_info()
top = n.old_contexts[1]
# Execute callback if no exception happened while restoring state
if top is None:
try:
ret = fn(*args, **kwargs)
except:
exc = sys.exc_info()
top = contexts[1]
# If there was exception, try to handle it by going through the exception chain
if top is not None:
exc = _handle_exception(top, exc)
else:
# Otherwise take shorter path and run stack contexts in reverse order
while last_ctx > 0:
last_ctx -= 1
c = stack[last_ctx]
try:
c.exit(*exc)
except:
exc = sys.exc_info()
top = c.old_contexts[1]
break
else:
top = None
# If if exception happened while unrolling, take longer exception handler path
if top is not None:
exc = _handle_exception(top, exc)
# If exception was not handled, raise it
if exc != (None, None, None):
raise_exc_info(exc)
finally:
_state.contexts = current_state
return ret
wrapped._wrapped = True
return wrapped
def _handle_exception(tail, exc):
while tail is not None:
try:
if tail.exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
tail = tail.old_contexts[1]
return exc
def run_with_stack_context(context, func):
"""Run a coroutine ``func`` in the given `StackContext`.
It is not safe to have a ``yield`` statement within a ``with StackContext``
block, so it is difficult to use stack context with `.gen.coroutine`.
This helper function runs the function in the correct context while
keeping the ``yield`` and ``with`` statements syntactically separate.
Example::
@gen.coroutine
def incorrect():
with StackContext(ctx):
# ERROR: this will raise StackContextInconsistentError
yield other_coroutine()
@gen.coroutine
def correct():
yield run_with_stack_context(StackContext(ctx), other_coroutine)
.. versionadded:: 3.1
"""
with context:
return func()

View file

@ -0,0 +1,179 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import socket
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
_INITIAL_CONNECT_TIMEOUT = 0.3
class _Connector(object):
"""A stateless implementation of the "Happy Eyeballs" algorithm.
"Happy Eyeballs" is documented in RFC6555 as the recommended practice
for when both IPv4 and IPv6 addresses are available.
In this implementation, we partition the addresses by family, and
make the first connection attempt to whichever address was
returned first by ``getaddrinfo``. If that connection fails or
times out, we begin a connection in parallel to the first address
of the other family. If there are additional failures we retry
with other addresses, keeping one connection attempt per family
in flight at a time.
http://tools.ietf.org/html/rfc6555
"""
def __init__(self, addrinfo, io_loop, connect):
self.io_loop = io_loop
self.connect = connect
self.future = Future()
self.timeout = None
self.last_error = None
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
@staticmethod
def split(addrinfo):
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
``addrinfo`` and all others with the same family, and the
second list contains all other addresses (normally one list will
be AF_INET and the other AF_INET6, although non-standard resolvers
may return additional families).
"""
primary = []
secondary = []
primary_af = addrinfo[0][0]
for af, addr in addrinfo:
if af == primary_af:
primary.append((af, addr))
else:
secondary.append((af, addr))
return primary, secondary
def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
self.try_connect(iter(self.primary_addrs))
self.set_timout(timeout)
return self.future
def try_connect(self, addrs):
try:
af, addr = next(addrs)
except StopIteration:
# We've reached the end of our queue, but the other queue
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
self.future.set_exception(self.last_error or
IOError("connection failed"))
return
future = self.connect(af, addr)
future.add_done_callback(functools.partial(self.on_connect_done,
addrs, af, addr))
def on_connect_done(self, addrs, af, addr, future):
self.remaining -= 1
try:
stream = future.result()
except Exception as e:
if self.future.done():
return
# Error: try again (but remember what happened so we have an
# error to raise in the end)
self.last_error = e
self.try_connect(addrs)
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.on_timeout()
return
self.clear_timeout()
if self.future.done():
# This is a late arrival; just drop it.
stream.close()
else:
self.future.set_result((af, addr, stream))
def set_timout(self, timeout):
self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
self.on_timeout)
def on_timeout(self):
self.timeout = None
self.try_connect(iter(self.secondary_addrs))
def clear_timeout(self):
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
class TCPClient(object):
"""A non-blocking TCP connection factory.
"""
def __init__(self, resolver=None, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
else:
self.resolver = Resolver(io_loop=io_loop)
self._own_resolver = True
def close(self):
if self._own_resolver:
self.resolver.close()
@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
max_buffer_size=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
"""
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo, self.io_loop,
functools.partial(self._create_stream, max_buffer_size))
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on sbusequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
stream = yield stream.start_tls(False, ssl_options=ssl_options,
server_hostname=host)
raise gen.Return(stream)
def _create_stream(self, max_buffer_size, af, addr):
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
stream = IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
return stream.connect(addr)

View file

@ -0,0 +1,257 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded TCP server."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import socket
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
from tornado import process
from tornado.util import errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
To make this server serve SSL traffic, send the ssl_options dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including "certfile" and "keyfile"::
TCPServer(ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
`TCPServer` initialization follows one of three patterns:
1. `listen`: simple single-process::
server = TCPServer()
server.listen(8888)
IOLoop.instance().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
the server on the default singleton `.IOLoop`.
3. `add_sockets`: advanced multi-process::
sockets = bind_sockets(8888)
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.instance().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
flexibility in when the fork happens. `add_sockets` can
also be used in single-process servers if you want to create
your listening sockets in some way other than
`~tornado.netutil.bind_sockets`.
.. versionadded:: 3.1
The ``max_buffer_size`` argument.
"""
def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None,
read_chunk_size=None):
self.io_loop = io_loop
self.ssl_options = ssl_options
self._sockets = {} # fd -> socket object
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = None
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
# the SSL module doesn't do that until there is a connected socket
# which seems like too much work
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
if 'certfile' not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
if not os.path.exists(self.ssl_options['certfile']):
raise ValueError('certfile "%s" does not exist' %
self.ssl_options['certfile'])
if ('keyfile' in self.ssl_options and
not os.path.exists(self.ssl_options['keyfile'])):
raise ValueError('keyfile "%s" does not exist' %
self.ssl_options['keyfile'])
def listen(self, port, address=""):
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
`listen` takes effect immediately; it is not necessary to call
`TCPServer.start` afterwards. It is, however, necessary to start
the `.IOLoop`.
"""
sockets = bind_sockets(port, address=address)
self.add_sockets(sockets)
def add_sockets(self, sockets):
"""Makes this server start accepting connections on the given sockets.
The ``sockets`` parameter is a list of socket objects such as
those returned by `~tornado.netutil.bind_sockets`.
`add_sockets` is typically used in combination with that
method and `tornado.process.fork_processes` to provide greater
control over the initialization of a multi-process server.
"""
if self.io_loop is None:
self.io_loop = IOLoop.current()
for sock in sockets:
self._sockets[sock.fileno()] = sock
add_accept_handler(sock, self._handle_connection,
io_loop=self.io_loop)
def add_socket(self, socket):
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
in a single process, you can call `listen` as a shortcut to the
sequence of `bind` and `start` calls.
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen <socket.socket.listen>`.
This method may be called multiple times prior to `start` to listen
on multiple ports or interfaces.
"""
sockets = bind_sockets(port, address=address, family=family,
backlog=backlog)
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
def start(self, num_processes=1):
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
additional child process.
If num_processes is ``None`` or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If num_processes is given and > 1, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``TCPServer.start(n)``.
"""
assert not self._started
self._started = True
if num_processes != 1:
process.fork_processes(num_processes)
sockets = self._pending_sockets
self._pending_sockets = []
self.add_sockets(sockets)
def stop(self):
"""Stops listening for new connections.
Requests currently in progress may still continue after the
server is stopped.
"""
for fd, sock in self._sockets.items():
self.io_loop.remove_handler(fd)
sock.close()
def handle_stream(self, stream, address):
"""Override to handle a new `.IOStream` from an incoming connection."""
raise NotImplementedError()
def _handle_connection(self, connection, address):
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
connection = ssl_wrap_socket(connection,
self.ssl_options,
server_side=True,
do_handshake_on_connect=False)
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
else:
raise
except socket.error as err:
# If the connection is closed immediately after it is created
# (as in a port scan), we can get one of several errors.
# wrap_socket makes an internal call to getpeername,
# which may return either EINVAL (Mac OS X) or ENOTCONN
# (Linux). If it returns ENOTCONN, this error is
# silently swallowed by the ssl module, so we need to
# catch another error later on (AttributeError in
# SSLIOStream._do_ssl_handshake).
# To test this behavior, try nmap with the -sT flag.
# https://github.com/tornadoweb/tornado/pull/750
if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL):
return connection.close()
else:
raise
try:
if self.ssl_options is not None:
stream = SSLIOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
else:
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
self.handle_stream(stream, address)
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View file

@ -0,0 +1,865 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A simple template system that compiles templates to Python code.
Basic usage looks like::
t = template.Template("<html>{{ myvalue }}</html>")
print t.generate(myvalue="XXX")
`Loader` is a class that loads templates from a root directory and caches
the compiled templates::
loader = template.Loader("/home/btaylor")
print loader.load("test.html").generate(myvalue="XXX")
We compile all templates to raw Python. Error-reporting is currently... uh,
interesting. Syntax for the templates::
### base.html
<html>
<head>
<title>{% block title %}Default title{% end %}</title>
</head>
<body>
<ul>
{% for student in students %}
{% block student %}
<li>{{ escape(student.name) }}</li>
{% end %}
{% end %}
</ul>
</body>
</html>
### bold.html
{% extends "base.html" %}
{% block title %}A bolder title{% end %}
{% block student %}
<li><span style="bold">{{ escape(student.name) }}</span></li>
{% end %}
Unlike most other template systems, we do not put any restrictions on the
expressions you can include in your statements. ``if`` and ``for`` blocks get
translated exactly into Python, so you can do complex expressions like::
{% for student in [p for p in people if p.student and p.age > 23] %}
<li>{{ escape(student.name) }}</li>
{% end %}
Translating directly to Python means you can apply functions to expressions
easily, like the ``escape()`` function in the examples above. You can pass
functions in to your template just like any other variable
(In a `.RequestHandler`, override `.RequestHandler.get_template_namespace`)::
### Python code
def add(x, y):
return x + y
template.execute(add=add)
### The template
{{ add(1, 2) }}
We provide the functions `escape() <.xhtml_escape>`, `.url_escape()`,
`.json_encode()`, and `.squeeze()` to all templates by default.
Typical applications do not create `Template` or `Loader` instances by
hand, but instead use the `~.RequestHandler.render` and
`~.RequestHandler.render_string` methods of
`tornado.web.RequestHandler`, which load templates automatically based
on the ``template_path`` `.Application` setting.
Variable names beginning with ``_tt_`` are reserved by the template
system and should not be used by application code.
Syntax Reference
----------------
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
The contents may be any python expression, which will be escaped according
to the current autoescape setting and inserted into the output. Other
template directives use ``{% %}``. These tags may be escaped as ``{{!``
and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.
``{% apply *function* %}...{% end %}``
Applies a function to the output of all template code between ``apply``
and ``end``::
{% apply linkify %}{{name}} said: {{message}}{% end %}
Note that as an implementation detail apply blocks are implemented
as nested functions and thus may interact strangely with variables
set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}``
within loops.
``{% autoescape *function* %}``
Sets the autoescape mode for the current file. This does not affect
other files, even those referenced by ``{% include %}``. Note that
autoescaping can also be configured globally, at the `.Application`
or `Loader`.::
{% autoescape xhtml_escape %}
{% autoescape None %}
``{% block *name* %}...{% end %}``
Indicates a named, replaceable block for use with ``{% extends %}``.
Blocks in the parent template will be replaced with the contents of
the same-named block in a child template.::
<!-- base.html -->
<title>{% block title %}Default title{% end %}</title>
<!-- mypage.html -->
{% extends "base.html" %}
{% block title %}My page title{% end %}
``{% comment ... %}``
A comment which will be removed from the template output. Note that
there is no ``{% end %}`` tag; the comment goes from the word ``comment``
to the closing ``%}`` tag.
``{% extends *filename* %}``
Inherit from another template. Templates that use ``extends`` should
contain one or more ``block`` tags to replace content from the parent
template. Anything in the child template not contained in a ``block``
tag will be ignored. For an example, see the ``{% block %}`` tag.
``{% for *var* in *expr* %}...{% end %}``
Same as the python ``for`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% from *x* import *y* %}``
Same as the python ``import`` statement.
``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}``
Conditional statement - outputs the first section whose condition is
true. (The ``elif`` and ``else`` sections are optional)
``{% import *module* %}``
Same as the python ``import`` statement.
``{% include *filename* %}``
Includes another template file. The included file can see all the local
variables as if it were copied directly to the point of the ``include``
directive (the ``{% autoescape %}`` directive is an exception).
Alternately, ``{% module Template(filename, **kwargs) %}`` may be used
to include another template with an isolated namespace.
``{% module *expr* %}``
Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is
not escaped::
{% module Template("foo.html", arg=42) %}
``UIModules`` are a feature of the `tornado.web.RequestHandler`
class (and specifically its ``render`` method) and will not work
when the template system is used on its own in other contexts.
``{% raw *expr* %}``
Outputs the result of the given expression without autoescaping.
``{% set *x* = *y* %}``
Sets a local variable.
``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}``
Same as the python ``try`` statement.
``{% while *condition* %}... {% end %}``
Same as the python ``while`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import linecache
import os.path
import posixpath
import re
import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import bytes_type, ObjectDict, exec_in, unicode_type
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
class Template(object):
"""A compiled template.
We compile into Python from the given template_string. You can generate
the template from variables with generate().
"""
# note that the constructor's signature is not extracted with
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
def __init__(self, template_string, name="<string>", loader=None,
compress_whitespace=None, autoescape=_UNSET):
self.name = name
if compress_whitespace is None:
compress_whitespace = name.endswith(".html") or \
name.endswith(".js")
if autoescape is not _UNSET:
self.autoescape = autoescape
elif loader:
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
reader = _TemplateReader(name, escape.native_str(template_string))
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader, compress_whitespace)
self.loader = loader
try:
# Under python2.5, the fake filename used here must match
# the module name used in __name__ below.
# The dont_inherit flag prevents template.py's future imports
# from being applied to the generated code.
self.compiled = compile(
escape.to_unicode(self.code),
"%s.generated.py" % self.name.replace('.', '_'),
"exec", dont_inherit=True)
except Exception:
formatted_code = _format_code(self.code).rstrip()
app_log.error("%s code:\n%s", self.name, formatted_code)
raise
def generate(self, **kwargs):
"""Generate this template with the given arguments."""
namespace = {
"escape": escape.xhtml_escape,
"xhtml_escape": escape.xhtml_escape,
"url_escape": escape.url_escape,
"json_encode": escape.json_encode,
"squeeze": escape.squeeze,
"linkify": escape.linkify,
"datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes_type),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
"__name__": self.name.replace('.', '_'),
"__loader__": ObjectDict(get_source=lambda name: self.code),
}
namespace.update(self.namespace)
namespace.update(kwargs)
exec_in(self.compiled, namespace)
execute = namespace["_tt_execute"]
# Clear the traceback module's cache of source data now that
# we've generated a new template (mainly for this module's
# unittests, where different tests reuse the same name).
linecache.clearcache()
return execute()
def _generate_python(self, loader, compress_whitespace):
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
named_blocks = {}
ancestors = self._get_ancestors(loader)
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template,
compress_whitespace)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
buffer.close()
def _get_ancestors(self, loader):
ancestors = [self.file]
for chunk in self.file.body.chunks:
if isinstance(chunk, _ExtendsBlock):
if not loader:
raise ParseError("{% extends %} block found, but no "
"template loader")
template = loader.load(chunk.name, self.name)
ancestors.extend(template._get_ancestors(loader))
return ancestors
class BaseLoader(object):
"""Base class for template loaders.
You must use a template loader to use template constructs like
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None):
"""``autoescape`` must be either None or a string naming a function
in the template namespace, such as "xhtml_escape".
"""
self.autoescape = autoescape
self.namespace = namespace or {}
self.templates = {}
# self.lock protects self.templates. It's a reentrant lock
# because templates may load other templates via `include` or
# `extends`. Note that thanks to the GIL this code would be safe
# even without the lock, but could lead to wasted work as multiple
# threads tried to compile the same template simultaneously.
self.lock = threading.RLock()
def reset(self):
"""Resets the cache of compiled templates."""
with self.lock:
self.templates = {}
def resolve_path(self, name, parent_path=None):
"""Converts a possibly-relative path to absolute (used internally)."""
raise NotImplementedError()
def load(self, name, parent_path=None):
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
with self.lock:
if name not in self.templates:
self.templates[name] = self._create_template(name)
return self.templates[name]
def _create_template(self, name):
raise NotImplementedError()
class Loader(BaseLoader):
"""A template loader that loads from a single root directory.
"""
def __init__(self, root_directory, **kwargs):
super(Loader, self).__init__(**kwargs)
self.root = os.path.abspath(root_directory)
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
current_path = os.path.join(self.root, parent_path)
file_dir = os.path.dirname(os.path.abspath(current_path))
relative_path = os.path.abspath(os.path.join(file_dir, name))
if relative_path.startswith(self.root):
name = relative_path[len(self.root) + 1:]
return name
def _create_template(self, name):
path = os.path.join(self.root, name)
with open(path, "rb") as f:
template = Template(f.read(), name=name, loader=self)
return template
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
def __init__(self, dict, **kwargs):
super(DictLoader, self).__init__(**kwargs)
self.dict = dict
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
file_dir = posixpath.dirname(parent_path)
name = posixpath.normpath(posixpath.join(file_dir, name))
return name
def _create_template(self, name):
return Template(self.dict[name], name=name, loader=self)
class _Node(object):
def each_child(self):
return ()
def generate(self, writer):
raise NotImplementedError()
def find_named_blocks(self, loader, named_blocks):
for child in self.each_child():
child.find_named_blocks(loader, named_blocks)
class _File(_Node):
def __init__(self, template, body):
self.template = template
self.body = body
self.line = 0
def generate(self, writer):
writer.write_line("def _tt_execute():", self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
def each_child(self):
return (self.body,)
class _ChunkList(_Node):
def __init__(self, chunks):
self.chunks = chunks
def generate(self, writer):
for chunk in self.chunks:
chunk.generate(writer)
def each_child(self):
return self.chunks
class _NamedBlock(_Node):
def __init__(self, name, body, template, line):
self.name = name
self.body = body
self.template = template
self.line = line
def each_child(self):
return (self.body,)
def generate(self, writer):
block = writer.named_blocks[self.name]
with writer.include(block.template, self.line):
block.body.generate(writer)
def find_named_blocks(self, loader, named_blocks):
named_blocks[self.name] = self
_Node.find_named_blocks(self, loader, named_blocks)
class _ExtendsBlock(_Node):
def __init__(self, name):
self.name = name
class _IncludeBlock(_Node):
def __init__(self, name, reader, line):
self.name = name
self.template_name = reader.name
self.line = line
def find_named_blocks(self, loader, named_blocks):
included = loader.load(self.name, self.template_name)
included.file.find_named_blocks(loader, named_blocks)
def generate(self, writer):
included = writer.loader.load(self.name, self.template_name)
with writer.include(included, self.line):
included.file.body.generate(writer)
class _ApplyBlock(_Node):
def __init__(self, method, line, body=None):
self.method = method
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
method_name = "_tt_apply%d" % writer.apply_counter
writer.apply_counter += 1
writer.write_line("def %s():" % method_name, self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
self.method, method_name), self.line)
class _ControlBlock(_Node):
def __init__(self, statement, line, body=None):
self.statement = statement
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
# Just in case the body was empty
writer.write_line("pass", self.line)
class _IntermediateControlBlock(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
# In case the previous block was empty
writer.write_line("pass", self.line)
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
class _Statement(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
writer.write_line(self.statement, self.line)
class _Expression(_Node):
def __init__(self, expression, line, raw=False):
self.expression = expression
self.line = line
self.raw = raw
def generate(self, writer):
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
" _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
writer.current_template.autoescape, self.line)
writer.write_line("_tt_append(_tt_tmp)", self.line)
class _Module(_Expression):
def __init__(self, expression, line):
super(_Module, self).__init__("_tt_modules." + expression, line,
raw=True)
class _Text(_Node):
def __init__(self, value, line):
self.value = value
self.line = line
def generate(self, writer):
value = self.value
# Compress lots of white space to a single character. If the whitespace
# breaks a line, have it continue to break a line, but just with a
# single \n character
if writer.compress_whitespace and "<pre>" not in value:
value = re.sub(r"([\t ]+)", " ", value)
value = re.sub(r"(\s*\n\s*)", "\n", value)
if value:
writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
class ParseError(Exception):
"""Raised for template syntax errors."""
pass
class _CodeWriter(object):
def __init__(self, file, named_blocks, loader, current_template,
compress_whitespace):
self.file = file
self.named_blocks = named_blocks
self.loader = loader
self.current_template = current_template
self.compress_whitespace = compress_whitespace
self.apply_counter = 0
self.include_stack = []
self._indent = 0
def indent_size(self):
return self._indent
def indent(self):
class Indenter(object):
def __enter__(_):
self._indent += 1
return self
def __exit__(_, *args):
assert self._indent > 0
self._indent -= 1
return Indenter()
def include(self, template, line):
self.include_stack.append((self.current_template, line))
self.current_template = template
class IncludeTemplate(object):
def __enter__(_):
return self
def __exit__(_, *args):
self.current_template = self.include_stack.pop()[0]
return IncludeTemplate()
def write_line(self, line, line_number, indent=None):
if indent is None:
indent = self._indent
line_comment = ' # %s:%d' % (self.current_template.name, line_number)
if self.include_stack:
ancestors = ["%s:%d" % (tmpl.name, lineno)
for (tmpl, lineno) in self.include_stack]
line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
print(" " * indent + line + line_comment, file=self.file)
class _TemplateReader(object):
def __init__(self, name, text):
self.name = name
self.text = text
self.line = 1
self.pos = 0
def find(self, needle, start=0, end=None):
assert start >= 0, start
pos = self.pos
start += pos
if end is None:
index = self.text.find(needle, start)
else:
end += pos
assert end >= start
index = self.text.find(needle, start, end)
if index != -1:
index -= pos
return index
def consume(self, count=None):
if count is None:
count = len(self.text) - self.pos
newpos = self.pos + count
self.line += self.text.count("\n", self.pos, newpos)
s = self.text[self.pos:newpos]
self.pos = newpos
return s
def remaining(self):
return len(self.text) - self.pos
def __len__(self):
return self.remaining()
def __getitem__(self, key):
if type(key) is slice:
size = len(self)
start, stop, step = key.indices(size)
if start is None:
start = self.pos
else:
start += self.pos
if stop is not None:
stop += self.pos
return self.text[slice(start, stop, step)]
elif key < 0:
return self.text[key]
else:
return self.text[self.pos + key]
def __str__(self):
return self.text[self.pos:]
def _format_code(code):
lines = code.splitlines()
format = "%%%dd %%s\n" % len(repr(len(lines) + 1))
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
def _parse(reader, template, in_block=None, in_loop=None):
body = _ChunkList([])
while True:
# Find next template directive
curly = 0
while True:
curly = reader.find("{", curly)
if curly == -1 or curly + 1 == reader.remaining():
# EOF
if in_block:
raise ParseError("Missing {%% end %%} block for %s" %
in_block)
body.chunks.append(_Text(reader.consume(), reader.line))
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
if reader[curly + 1] not in ("{", "%", "#"):
curly += 1
continue
# When there are more than 2 curlies in a row, use the
# innermost ones. This is useful when generating languages
# like latex where curlies are also meaningful
if (curly + 2 < reader.remaining() and
reader[curly + 1] == '{' and reader[curly + 2] == '{'):
curly += 1
continue
break
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
body.chunks.append(_Text(cons, reader.line))
start_brace = reader.consume(2)
line = reader.line
# Template directives may be escaped as "{{!" or "{%!".
# In this case output the braces and consume the "!".
# This is especially useful in conjunction with jquery templates,
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
body.chunks.append(_Text(start_brace, line))
continue
# Comment
if start_brace == "{#":
end = reader.find("#}")
if end == -1:
raise ParseError("Missing end expression #} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
continue
# Expression
if start_brace == "{{":
end = reader.find("}}")
if end == -1:
raise ParseError("Missing end expression }} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty expression on line %d" % line)
body.chunks.append(_Expression(contents, line))
continue
# Block
assert start_brace == "{%", start_brace
end = reader.find("%}")
if end == -1:
raise ParseError("Missing end block %%} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty block tag ({%% %%}) on line %d" % line)
operator, space, suffix = contents.partition(" ")
suffix = suffix.strip()
# Intermediate ("else", "elif", etc) blocks
intermediate_blocks = {
"else": set(["if", "for", "while", "try"]),
"elif": set(["if"]),
"except": set(["try"]),
"finally": set(["try"]),
}
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
raise ParseError("%s outside %s block" %
(operator, allowed_parents))
if in_block not in allowed_parents:
raise ParseError("%s block cannot be attached to %s block" % (operator, in_block))
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
# End tag
elif operator == "end":
if not in_block:
raise ParseError("Extra {%% end %%} block on line %d" % line)
return body
elif operator in ("extends", "include", "set", "import", "from",
"comment", "autoescape", "raw", "module"):
if operator == "comment":
continue
if operator == "extends":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("extends missing file path on line %d" % line)
block = _ExtendsBlock(suffix)
elif operator in ("import", "from"):
if not suffix:
raise ParseError("import missing statement on line %d" % line)
block = _Statement(contents, line)
elif operator == "include":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("include missing file path on line %d" % line)
block = _IncludeBlock(suffix, reader, line)
elif operator == "set":
if not suffix:
raise ParseError("set missing statement on line %d" % line)
block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
if fn == "None":
fn = None
template.autoescape = fn
continue
elif operator == "raw":
block = _Expression(suffix, line, raw=True)
elif operator == "module":
block = _Module(suffix, line)
body.chunks.append(block)
continue
elif operator in ("apply", "block", "try", "if", "for", "while"):
# parse inner body recursively
if operator in ("for", "while"):
block_body = _parse(reader, template, operator, operator)
elif operator == "apply":
# apply creates a nested function so syntactically it's not
# in the loop.
block_body = _parse(reader, template, operator, None)
else:
block_body = _parse(reader, template, operator, in_loop)
if operator == "apply":
if not suffix:
raise ParseError("apply missing method name on line %d" % line)
block = _ApplyBlock(suffix, line, block_body)
elif operator == "block":
if not suffix:
raise ParseError("block missing name on line %d" % line)
block = _NamedBlock(suffix, block_body, template, line)
else:
block = _ControlBlock(contents, line, block_body)
body.chunks.append(block)
continue
elif operator in ("break", "continue"):
if not in_loop:
raise ParseError("%s outside %s block" % (operator, set(["for", "while"])))
body.chunks.append(_Statement(contents, line))
continue
else:
raise ParseError("unknown operator: %r" % operator)

View file

@ -0,0 +1,4 @@
Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH apprioriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect.

View file

@ -0,0 +1,14 @@
"""Shim to allow python -m tornado.test.
This only works in python 2.7+.
"""
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all
main()

View file

@ -0,0 +1,451 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
from __future__ import absolute_import, division, print_function, with_statement
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.util import u
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument('openid.mode', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
if self.get_argument('openid.mode') != 'check_authentication':
raise Exception("incorrect openid.mode %r")
self.write('is_valid:true')
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authorize_redirect(http_client=self.settings['http_client'])
assert isinstance(res, Future)
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def _oauth_get_user(self, access_token, callback):
if self.get_argument('fail_in_get_user', None):
raise Exception("failing in get_user")
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
callback(dict(email='foo@example.com'))
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
def get(self):
if self.get_argument('oauth_token', None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
yield self.get_authenticated_user()
except Exception as e:
self.set_status(503)
self.write("got exception: %s" % e)
else:
yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
def get(self):
params = self._oauth_request_parameters(
'http://www.example.com/api/asdf',
dict(key='uiop', secret='5678'),
parameters=dict(foo='bar'))
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=zxcv&oauth_token_secret=1234')
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=uiop&oauth_token_secret=5678')
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
def get(self):
res = self.authorize_redirect()
assert isinstance(res, Future)
assert res.done()
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
def get_auth_http_client(self):
return self.settings['http_client']
class TwitterClientLoginHandler(TwitterClientHandler):
@asynchronous
def get(self):
if self.get_argument("oauth_token", None):
self.get_authenticated_user(self.on_user)
return
self.authorize_redirect()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# Old style: with @gen.engine we can ignore the Future from
# authorize_redirect.
self.authorize_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# New style: with @gen.coroutine the result must be yielded
# or else the request will be auto-finished too soon.
yield self.authorize_redirect()
class TwitterClientShowUserHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
response = yield gen.Task(self.twitter_request,
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
if response is None:
self.set_status(500)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterClientShowUserFutureHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
try:
response = yield self.twitter_request(
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
except AuthError as e:
self.set_status(500)
self.finish(str(e))
return
assert response is not None
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
if screen_name == 'error':
raise HTTPError(500)
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name='foo', name='Foo'))
class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument("openid.mode", None):
self.get_authenticated_user(self.on_user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def get_auth_http_client(self):
return self.settings['http_client']
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
('/oauth10/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0')),
('/oauth10/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/login_coroutine',
OAuth1ClientLoginCoroutineHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
('/twitter/client/login_gen_engine', TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
(r'/twitter/api/account/verify_credentials\.json', TwitterServerVerifyCredentialsHandler),
],
http_client=self.http_client,
twitter_consumer_key='test_twitter_consumer_key',
twitter_consumer_secret='test_twitter_consumer_secret')
def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user(self):
response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect(self):
response = self.fetch('/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10_get_user(self):
response = self.fetch(
'/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10_request_parameters(self):
response = self.fetch('/oauth10/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_redirect(self):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10a_get_user(self):
response = self.fetch(
'/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10a_request_parameters(self):
response = self.fetch('/oauth10a/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_redirect(self):
self.base_twitter_redirect('/twitter/client/login')
def test_twitter_redirect_gen_engine(self):
self.base_twitter_redirect('/twitter/client/login_gen_engine')
def test_twitter_redirect_gen_coroutine(self):
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
def test_twitter_get_user(self):
response = self.fetch(
'/twitter/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed,
{u('access_token'): {u('key'): u('hjkl'),
u('screen_name'): u('foo'),
u('secret'): u('vbnm')},
u('name'): u('Foo'),
u('screen_name'): u('foo'),
u('username'): u('foo')})
def test_twitter_show_user(self):
response = self.fetch('/twitter/client/show_user?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error(self):
with ExpectLog(gen_log, 'Error response HTTP 500'):
response = self.fetch('/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
def test_twitter_show_user_future(self):
response = self.fetch('/twitter/client/show_user_future?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_future_error(self):
response = self.fetch('/twitter/client/show_user_future?name=error')
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
def test_google_redirect(self):
# same as test_openid_redirect
response = self.fetch('/google/client/openid_login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_google_get_user(self):
response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")

View file

@ -0,0 +1,336 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
import logging
import re
import socket
import sys
import traceback
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
try:
from concurrent import futures
except ImportError:
futures = None
class ReturnFutureTest(AsyncTestCase):
@return_future
def sync_future(self, callback):
callback(42)
@return_future
def async_future(self, callback):
self.io_loop.add_callback(callback, 42)
@return_future
def immediate_failure(self, callback):
1 / 0
@return_future
def delayed_failure(self, callback):
self.io_loop.add_callback(lambda: 1 / 0)
@return_future
def return_value(self, callback):
# Note that the result of both running the callback and returning
# a value (or raising an exception) is unspecified; with current
# implementations the last event prior to callback resolution wins.
return 42
@return_future
def no_result_future(self, callback):
callback()
def test_immediate_failure(self):
with self.assertRaises(ZeroDivisionError):
# The caller sees the error just like a normal function.
self.immediate_failure(callback=self.stop)
# The callback is not run because the function failed synchronously.
self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
result = self.wait()
self.assertIs(result, None)
def test_return_value(self):
with self.assertRaises(ReturnValueIgnoredError):
self.return_value(callback=self.stop)
def test_callback_kw(self):
future = self.sync_future(callback=self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_callback_positional(self):
# When the callback is passed in positionally, future_wrap shouldn't
# add another callback in the kwargs.
future = self.sync_future(self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_no_callback(self):
future = self.sync_future()
self.assertEqual(future.result(), 42)
def test_none_callback_kw(self):
# explicitly pass None as callback
future = self.sync_future(callback=None)
self.assertEqual(future.result(), 42)
def test_none_callback_pos(self):
future = self.sync_future(None)
self.assertEqual(future.result(), 42)
def test_async_future(self):
future = self.async_future()
self.assertFalse(future.done())
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
self.assertEqual(future.result(), 42)
@gen_test
def test_async_future_gen(self):
result = yield self.async_future()
self.assertEqual(result, 42)
def test_delayed_failure(self):
future = self.delayed_failure()
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
with self.assertRaises(ZeroDivisionError):
future.result()
def test_kw_only_callback(self):
@return_future
def f(**kwargs):
kwargs['callback'](42)
future = f()
self.assertEqual(future.result(), 42)
def test_error_in_callback(self):
self.sync_future(callback=lambda future: 1 / 0)
# The exception gets caught by our StackContext and will be re-raised
# when we wait.
self.assertRaises(ZeroDivisionError, self.wait)
def test_no_result_future(self):
future = self.no_result_future(self.stop)
result = self.wait()
self.assertIs(result, None)
# result of this future is undefined, but not an error
future.result()
def test_no_result_future_callback(self):
future = self.no_result_future(callback=lambda: self.stop())
result = self.wait()
self.assertIs(result, None)
future.result()
@gen_test
def test_future_traceback(self):
@return_future
@gen.engine
def f(callback):
yield gen.Task(self.io_loop.add_callback)
try:
1 / 0
except ZeroDivisionError:
self.expected_frame = traceback.extract_tb(
sys.exc_info()[2], limit=1)[0]
raise
try:
yield f()
self.fail("didn't get expected exception")
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
def handle_stream(self, stream, address):
logging.info("handle_stream")
self.stream = stream
self.stream.read_until(b"\n", self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
data = to_unicode(data)
if data == data.upper():
self.stream.write(b"error\talready capitalized\n")
else:
# data already has \n
self.stream.write(utf8("ok\t%s" % data.upper()))
self.stream.close()
class CapError(Exception):
pass
class BaseCapClient(object):
def __init__(self, port, io_loop):
self.port = port
self.io_loop = io_loop
def process_response(self, data):
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
if status == 'ok':
return message
else:
raise CapError(message)
class ManualCapClient(BaseCapClient):
def capitalize(self, request_data, callback=None):
logging.info("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.future = Future()
if callback is not None:
self.future.add_done_callback(
stack_context.wrap(lambda future: callback(future.result())))
return self.future
def handle_connect(self):
logging.info("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
self.stream.close()
try:
self.future.set_result(self.process_response(data))
except CapError as e:
self.future.set_exception(e)
class DecoratorCapClient(BaseCapClient):
@return_future
def capitalize(self, request_data, callback):
logging.info("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.callback = callback
def handle_connect(self):
logging.info("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
self.stream.close()
self.callback(self.process_response(data))
class GeneratorCapClient(BaseCapClient):
@return_future
@gen.engine
def capitalize(self, request_data, callback):
logging.info('capitalize')
stream = IOStream(socket.socket(), io_loop=self.io_loop)
logging.info('connecting')
yield gen.Task(stream.connect, ('127.0.0.1', self.port))
stream.write(utf8(request_data + '\n'))
logging.info('reading')
data = yield gen.Task(stream.read_until, b'\n')
logging.info('returning')
stream.close()
callback(self.process_response(data))
class ClientTestMixin(object):
def setUp(self):
super(ClientTestMixin, self).setUp()
self.server = CapServer(io_loop=self.io_loop)
sock, port = bind_unused_port()
self.server.add_sockets([sock])
self.client = self.client_class(io_loop=self.io_loop, port=port)
def tearDown(self):
self.server.stop()
super(ClientTestMixin, self).tearDown()
def test_callback(self):
self.client.capitalize("hello", callback=self.stop)
result = self.wait()
self.assertEqual(result, "HELLO")
def test_callback_error(self):
self.client.capitalize("HELLO", callback=self.stop)
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
def test_future(self):
future = self.client.capitalize("hello")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertEqual(future.result(), "HELLO")
def test_future_error(self):
future = self.client.capitalize("HELLO")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
def test_generator(self):
@gen.engine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
self.stop()
f()
self.wait()
def test_generator_error(self):
@gen.engine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
self.stop()
f()
self.wait()
class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = ManualCapClient
class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = DecoratorCapClient
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = GeneratorCapClient

View file

@ -0,0 +1 @@
"school","école"
1 school école

View file

@ -0,0 +1,122 @@
from __future__ import absolute_import, division, print_function, with_statement
from hashlib import md5
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
try:
import pycurl
except ImportError:
pycurl = None
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(io_loop=self.io_loop,
defaults=dict(allow_ipv6=False))
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
class DigestAuthHandler(RequestHandler):
def get(self):
realm = 'test'
opaque = 'asdf'
# Real implementations would use a random nonce.
nonce = "1234"
username = 'foo'
password = 'bar'
auth_header = self.request.headers.get('Authorization', None)
if auth_header is not None:
auth_mode, params = auth_header.split(' ', 1)
assert auth_mode == 'Digest'
param_dict = {}
for pair in params.split(','):
k, v = pair.strip().split('=', 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
assert param_dict['realm'] == realm
assert param_dict['opaque'] == opaque
assert param_dict['nonce'] == nonce
assert param_dict['username'] == username
assert param_dict['uri'] == self.request.path
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
h2 = md5(utf8('%s:%s' % (self.request.method,
self.request.path))).hexdigest()
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
if digest == param_dict['response']:
self.write('ok')
else:
self.write('fail')
else:
self.set_status(401)
self.set_header('WWW-Authenticate',
'Digest realm="%s", nonce="%s", opaque="%s"' %
(realm, nonce, opaque))
class CustomReasonHandler(RequestHandler):
def get(self):
self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
def get(self):
self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = CurlAsyncHTTPClient(self.io_loop,
defaults=dict(allow_ipv6=False))
def get_app(self):
return Application([
('/digest', DigestAuthHandler),
('/custom_reason', CustomReasonHandler),
('/custom_fail_reason', CustomFailReasonHandler),
])
def test_prepare_curl_callback_stack_context(self):
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
self.stop()
return True
with ExceptionStackContext(error_handler):
request = HTTPRequest(self.get_url('/'),
prepare_curl_callback=lambda curl: 1 / 0)
self.http_client.fetch(request, callback=self.stop)
self.wait()
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_digest_auth(self):
response = self.fetch('/digest', auth_mode='digest',
auth_username='foo', auth_password='bar')
self.assertEqual(response.body, b'ok')
def test_custom_reason(self):
response = self.fetch('/custom_reason')
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")

View file

@ -0,0 +1,217 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode
from tornado.util import u, unicode_type, bytes_type
from tornado.test.util import unittest
linkify_tests = [
# (input, linkify_kwargs, expected_output)
("hello http://world.com/!", {},
u('hello <a href="http://world.com/">http://world.com/</a>!')),
("hello http://world.com/with?param=true&stuff=yes", {},
u('hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>')),
# an opened paren followed by many chars killed Gruber's regex
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
u('<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')),
# as did too many dots at the end
("http://url.com/withmany.......................................", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................')),
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)')),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
("http://foo.com/blah_blah", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>')),
("http://foo.com/blah_blah/", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>')),
("(Something like http://foo.com/blah_blah)", {},
u('(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)')),
("http://foo.com/blah_blah_(wikipedia)", {},
u('<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>')),
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
u('<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>')),
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
u('(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)')),
("http://foo.com/blah_blah.", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.')),
("http://foo.com/blah_blah/.", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.')),
("<http://foo.com/blah_blah>", {},
u('&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;')),
("<http://foo.com/blah_blah/>", {},
u('&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;')),
("http://foo.com/blah_blah,", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,')),
("http://www.example.com/wpstyle/?p=364.", {},
u('<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.')),
("rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u('<a href="rdar://1234">rdar://1234</a>')),
("rdar:/1234",
{"permitted_protocols": ["rdar"]},
u('<a href="rdar:/1234">rdar:/1234</a>')),
("http://userid:password@example.com:8080", {},
u('<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>')),
("http://userid@example.com", {},
u('<a href="http://userid@example.com">http://userid@example.com</a>')),
("http://userid@example.com:8080", {},
u('<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>')),
("http://userid:password@example.com", {},
u('<a href="http://userid:password@example.com">http://userid:password@example.com</a>')),
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u('<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>')),
(u("http://\u27a1.ws/\u4a39"), {},
u('<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>')),
("<tag>http://example.com</tag>", {},
u('&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;')),
("Just a www.example.com link.", {},
u('Just a <a href="http://www.example.com">www.example.com</a> link.')),
("Just a www.example.com link.",
{"require_protocol": True},
u('Just a www.example.com link.')),
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u('A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html" title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>')),
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u('A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi" title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!')),
("A file:///passwords.txt and http://web.com link", {},
u('A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link')),
("A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u('A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link')),
("www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
("www.external-link.com and www.internal-link.com/blogs extra",
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a> and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra')),
("www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
]
class EscapeTestCase(unittest.TestCase):
def test_linkify(self):
for text, kwargs, html in linkify_tests:
linked = tornado.escape.linkify(text, **kwargs)
self.assertEqual(linked, html)
def test_xhtml_escape(self):
tests = [
("<foo>", "&lt;foo&gt;"),
(u("<foo>"), u("&lt;foo&gt;")),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
("&amp;", "&amp;amp;"),
(u("<\u00e9>"), u("&lt;\u00e9&gt;")),
(b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
]
for unescaped, escaped in tests:
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
(u('\u00e9').encode('utf8'), '%C3%A9'),
(u('\u00e9').encode('latin1'), '%E9'),
# unicode strings become utf8
(u('\u00e9'), '%C3%A9'),
]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
('%C3%A9', u('\u00e9'), 'utf8'),
('%C3%A9', u('\u00c3\u00a9'), 'latin1'),
('%C3%A9', utf8(u('\u00e9')), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
# characters, but make sure the function accepts both byte
# and unicode strings.
self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
unescaped = '+ #%'
plus_escaped = '%2B+%23%25'
escaped = '%2B%20%23%25'
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
self.assertEqual(url_unescape(plus_escaped, encoding=None),
utf8(unescaped))
self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
utf8(unescaped))
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
# type as their argument
self.assertEqual(type(xhtml_escape("foo")), str)
self.assertEqual(type(xhtml_escape(u("foo"))), unicode_type)
def test_json_decode(self):
# json_decode accepts both bytes and unicode, but strings it returns
# are always unicode.
self.assertEqual(json_decode(b'"foo"'), u("foo"))
self.assertEqual(json_decode(u('"foo"')), u("foo"))
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u('"\u00e9"'))), u("\u00e9"))
def test_json_encode(self):
# json deals with strings, not bytes. On python 2 byte strings will
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
if bytes_type is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,22 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2012-06-14 01:10-0700\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"Language: \n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
#: extract_me.py:1
msgid "school"
msgstr "école"

View file

@ -0,0 +1,517 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import base64
import binascii
from contextlib import closing
import functools
import sys
import threading
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, url
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO
class HelloWorldHandler(RequestHandler):
def get(self):
name = self.get_argument("name", "world")
self.set_header("Content-Type", "text/plain")
self.finish("Hello %s!" % name)
class PostHandler(RequestHandler):
def post(self):
self.finish("Post arg1: %s, arg2: %s" % (
self.get_argument("arg1"), self.get_argument("arg2")))
class ChunkHandler(RequestHandler):
def get(self):
self.write("asdf")
self.flush()
self.write("qwer")
class AuthHandler(RequestHandler):
def get(self):
self.finish(self.request.headers["Authorization"])
class CountdownHandler(RequestHandler):
def get(self, count):
count = int(count)
if count > 0:
self.redirect(self.reverse_url("countdown", count - 1))
else:
self.write("Zero")
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
class UserAgentHandler(RequestHandler):
def get(self):
self.write(self.request.headers.get('User-Agent', 'User agent not set'))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
self.set_header('Content-Length', 42)
def _clear_headers_for_304(self):
# Tornado strips content-length from 304 responses, but here we
# want to simulate servers that include the headers anyway.
pass
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
def method(self):
self.write(self.request.method)
get = post = put = delete = options = patch = other = method
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/echopost", EchoPostHandler),
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
], gzip=True)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b"Hello world!")
self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, b"Hello Ben!")
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = []
response = self.fetch("/hello",
streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
response = self.fetch("/post", method="POST",
body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
chunks = []
response = self.fetch("/chunk",
streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
def test_chunked_close(self):
# test case in which chunks spread read-callback processing
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
1
1
1
2
0
""".replace(b"\n", b"\r\n"), callback=stream.close)
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn, io_loop=self.io_loop)
stream.read_until(b"\r\n\r\n",
functools.partial(write_response, stream))
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
resp = self.wait()
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
def test_streaming_stack_context(self):
chunks = []
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def streaming_cb(chunk):
chunks.append(chunk)
if chunk == b'qwer':
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', streaming_callback=streaming_cb)
self.assertEqual(chunks, [b'asdf', b'qwer'])
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_basic_auth(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_basic_auth_explicit_mode(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="basic").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
response = self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf")
response.rethrow()
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
self.assertEqual(302, response.code)
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
response = self.fetch("/countdown/2")
self.assertEqual(200, response.code)
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b"Zero", response.body)
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
response.body)
def test_body_encoding(self):
unicode_body = u("\xe9")
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
response = self.fetch("/echopost", method="POST", body=unicode_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
response = self.fetch("/echopost", method="POST",
body=byte_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
response = self.fetch("/echopost", method="POST", body=byte_body,
headers={"Content-Type": "application/blah"},
user_agent=u("foo"))
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
if header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
headers[k] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
self.assertEqual(len(first_line), 1)
self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def header_callback(header_line):
if header_line.startswith('Content-Type:'):
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', header_callback=header_callback)
self.assertEqual(len(exc_info), 1)
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_configure_defaults(self):
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults)
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
self.assertEqual(response.body, b'TestDefaultUserAgent')
client.close()
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
response = self.fetch('/304_with_content_length')
self.assertEqual(response.code, 304)
self.assertEqual(response.headers['Content-Length'], '42')
def test_final_callback_stack_context(self):
# The final callback should be run outside of the httpclient's
# stack_context. We want to ensure that there is not stack_context
# between the user's callback and the IOLoop, so monkey-patch
# IOLoop.handle_callback_exception and disable the test harness's
# context with a NullContext.
# Note that this does not apply to secondary callbacks (header
# and streaming_callback), as errors there must be seen as errors
# by the http client so it can clean up the connection.
exc_info = []
def handle_callback_exception(callback):
exc_info.append(sys.exc_info())
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
self.http_client.fetch(self.get_url('/hello'),
lambda response: 1 / 0)
self.wait()
self.assertEqual(exc_info[0][0], ZeroDivisionError)
@gen_test
def test_future_interface(self):
response = yield self.http_client.fetch(self.get_url('/hello'))
self.assertEqual(response.body, b'Hello world!')
@gen_test
def test_future_http_error(self):
with self.assertRaises(HTTPError) as context:
yield self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
url = self.get_url('/hello')
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
self.assertEqual(response2.body, b'Hello world!')
def test_all_methods(self):
for method in ['GET', 'DELETE', 'OPTIONS']:
response = self.fetch('/all_methods', method=method)
self.assertEqual(response.body, utf8(method))
for method in ['POST', 'PUT', 'PATCH']:
response = self.fetch('/all_methods', method=method, body=b'')
self.assertEqual(response.body, utf8(method))
response = self.fetch('/all_methods', method='HEAD')
self.assertEqual(response.body, b'')
response = self.fetch('/all_methods', method='OTHER',
allow_nonstandard_methods=True)
self.assertEqual(response.body, b'OTHER')
@gen_test
def test_body(self):
hello_url = self.get_url('/hello')
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, body='data')
self.assertTrue('must be empty' in str(context.exception))
with self.assertRaises(AssertionError) as context:
yield self.http_client.fetch(hello_url, method='POST')
self.assertTrue('must not be empty' in str(context.exception))
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
user_agent='foo'),
dict())
self.assertEqual(proxy.user_agent, 'foo')
def test_default_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict(network_interface='foo'))
self.assertEqual(proxy.network_interface, 'foo')
def test_both_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
proxy_host='foo'),
dict(proxy_host='bar'))
self.assertEqual(proxy.proxy_host, 'foo')
def test_neither_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
response = HTTPResponse(HTTPRequest('http://example.com'),
200, headers={}, buffer=BytesIO())
s = str(response)
self.assertTrue(s.startswith('HTTPResponse('))
self.assertIn('code=200', s)
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
if IOLoop.configured_class().__name__ in ('TwistedIOLoop',
'AsyncIOMainLoop'):
# TwistedIOLoop only supports the global reactor, so we can't have
# separate IOLoops for client and server threads.
# AsyncIOMainLoop doesn't work with the default policy
# (although it could with some tweaks to this test and a
# policy that created loops for non-main threads).
raise unittest.SkipTest(
'Sync HTTPClient not compatible with TwistedIOLoop or '
'AsyncIOMainLoop')
self.server_ioloop = IOLoop()
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
self.server = HTTPServer(app, io_loop=self.server_ioloop)
self.server.add_socket(sock)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
self.server_thread.start()
self.http_client = HTTPClient()
def tearDown(self):
def stop_server():
self.server.stop()
self.server_ioloop.stop()
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return 'http://localhost:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))
self.assertEqual(b'Hello world!', response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(assertion.exception.code, 404)
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
self.assertEqual(request.headers, {'foo': 'bar'})
def test_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = {'bar': 'baz'}
self.assertEqual(request.headers, {'bar': 'baz'})
def test_null_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = None
self.assertEqual(request.headers, {})
def test_body(self):
request = HTTPRequest('http://example.com', body='foo')
self.assertEqual(request.body, utf8('foo'))
def test_body_setter(self):
request = HTTPRequest('http://example.com')
request.body = 'foo'
self.assertEqual(request.body, utf8('foo'))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,255 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp
from tornado.escape import utf8
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
import datetime
import logging
import time
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
url = url_concat(
"https://localhost/path",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
url = url_concat(
"https://localhost/path",
[('y', '/y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
url = url_concat(
"https://localhost/path?",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
url = url_concat(
"https://localhost/path?x",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat(
"https://localhost/path?x&",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat(
"https://localhost/path?a=1&b=2",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
[],
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
class MultipartFormDataTest(unittest.TestCase):
def test_file_upload(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_unquoted_names(self):
# quotes are optional unless special characters are present
data = b"""\
--1234
Content-Disposition: form-data; name=files; filename=ab.txt
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_special_filenames(self):
filenames = ['a;b.txt',
'a"b.txt',
'a";b.txt',
'a;"b.txt',
'a";";.txt',
'a\\"b.txt',
'a\\b.txt',
]
for filename in filenames:
logging.debug("trying filename %r", filename)
data = """\
--1234
Content-Disposition: form-data; name="files"; filename="%s"
Foo
--1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
data = utf8(data.replace("\n", "\r\n"))
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], filename)
self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b'"1234"', data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_missing_headers(self):
data = b'''\
--1234
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data missing headers"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_invalid_content_disposition(self):
data = b'''\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_line_does_not_end_with_correct_line_break(self):
data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_content_disposition_header_without_name_parameter(self):
data = b"""\
--1234
Content-Disposition: form-data; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data value missing name"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_data_after_final_boundary(self):
# The spec requires that data after the final boundary be ignored.
# http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
# In practice, some libraries include an extra CRLF after the boundary.
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--
""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
class HTTPHeadersTest(unittest.TestCase):
def test_multi_line(self):
# Lines beginning with whitespace are appended to the previous line
# with any leading whitespace replaced by a single space.
# Note that while multi-line headers are a part of the HTTP spec,
# their use is strongly discouraged.
data = """\
Foo: bar
baz
Asdf: qwer
\tzxcv
Foo: even
more
lines
""".replace("\n", "\r\n")
headers = HTTPHeaders.parse(data)
self.assertEqual(headers["asdf"], "qwer zxcv")
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
self.assertEqual(headers["Foo"], "bar baz,even more lines")
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
self.assertEqual(sorted(list(headers.get_all())),
[("Asdf", "qwer zxcv"),
("Foo", "bar baz"),
("Foo", "even more lines")])
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
TIMESTAMP = 1359312200.503611
EXPECTED = 'Sun, 27 Jan 2013 18:43:20 GMT'
def check(self, value):
self.assertEqual(format_timestamp(value), self.EXPECTED)
def test_unix_time_float(self):
self.check(self.TIMESTAMP)
def test_unix_time_int(self):
self.check(int(self.TIMESTAMP))
def test_struct_time(self):
self.check(time.gmtime(self.TIMESTAMP))
def test_time_tuple(self):
tup = tuple(time.gmtime(self.TIMESTAMP))
self.assertEqual(9, len(tup))
self.check(tup)
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))

View file

@ -0,0 +1,46 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.util import unittest
class ImportTest(unittest.TestCase):
def test_import_everything(self):
# Some of our modules are not otherwise tested. Import them
# all (unless they have external dependencies) here to at
# least ensure that there are no syntax errors.
import tornado.auth
import tornado.autoreload
import tornado.concurrent
# import tornado.curl_httpclient # depends on pycurl
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.tcpserver
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
# for modules with dependencies, if those dependencies can be loaded,
# load them too.
def test_import_pycurl(self):
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient

View file

@ -0,0 +1,461 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import datetime
import functools
import socket
import sys
import threading
import time
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
try:
from concurrent import futures
except ImportError:
futures = None
class TestIOLoop(AsyncTestCase):
@skipOnTravis
def test_add_callback_wakeup(self):
# Make sure that add_callback from inside a running IOLoop
# wakes up the IOLoop immediately instead of waiting for a timeout.
def callback():
self.called = True
self.stop()
def schedule_callback():
self.called = False
self.io_loop.add_callback(callback)
# Store away the time so we can check if we woke up immediately
self.start_time = time.time()
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
self.wait()
self.assertAlmostEqual(time.time(), self.start_time, places=2)
self.assertTrue(self.called)
@skipOnTravis
def test_add_callback_wakeup_other_thread(self):
def target():
# sleep a bit to let the ioloop go into its poll loop
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
self.io_loop.add_handler(sock.fileno(), lambda fd, events: None,
IOLoop.READ)
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
self.assertRaises(Exception, self.io_loop.add_handler,
sock.fileno(), lambda fd, events: None,
IOLoop.READ)
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
def test_remove_without_add(self):
# remove_handler should not throw an exception if called on an fd
# was never added.
sock, port = bind_unused_port()
try:
self.io_loop.remove_handler(sock.fileno())
finally:
sock.close()
def test_add_callback_from_signal(self):
# cheat a little bit and just run this normally, since we can't
# easily simulate the races that happen with real signal handlers
self.io_loop.add_callback_from_signal(self.stop)
self.wait()
def test_add_callback_from_signal_other_thread(self):
# Very crude test, just to make sure that we cover this case.
# This also happens to be the first test where we run an IOLoop in
# a non-main thread.
other_ioloop = IOLoop()
thread = threading.Thread(target=other_ioloop.start)
thread.start()
other_ioloop.add_callback_from_signal(other_ioloop.stop)
thread.join()
other_ioloop.close()
def test_add_callback_while_closing(self):
# Issue #635: add_callback() should raise a clean exception
# if called while another thread is closing the IOLoop.
closing = threading.Event()
def target():
other_ioloop.add_callback(other_ioloop.stop)
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
closing.wait()
for i in range(1000):
try:
other_ioloop.add_callback(lambda: None)
except RuntimeError as e:
self.assertEqual("IOLoop is closing", str(e))
break
def test_handle_callback_exception(self):
# IOLoop.handle_callback_exception can be overridden to catch
# exceptions in callbacks.
def handle_callback_exception(callback):
self.assertIs(sys.exc_info()[0], ZeroDivisionError)
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
# remove the test StackContext that would see this uncaught
# exception as a test failure.
self.io_loop.add_callback(lambda: 1 / 0)
self.wait()
@skipIfNonUnix # just because socketpair is so convenient
def test_read_while_writeable(self):
# Ensure that write events don't come in while we're waiting for
# a read and haven't asked for writeability. (the reverse is
# difficult to test for)
client, server = socket.socketpair()
try:
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
self.io_loop.add_timeout(self.io_loop.time() + 0.01,
functools.partial(server.send, b'asdf'))
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
client.close()
server.close()
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
def test_remove_timeout_cleanup(self):
# Add and remove enough callbacks to trigger cleanup.
# Not a very thorough test, but it ensures that the cleanup code
# gets executed and doesn't blow up. This test is only really useful
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600,
lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
with NullContext():
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1/0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1/0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_current(self):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
class TestIOLoopAddCallback(AsyncTestCase):
def setUp(self):
super(TestIOLoopAddCallback, self).setUp()
self.active_contexts = []
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback(callback, *args, **kwargs)
@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
def test_pre_wrap(self):
# A pre-wrapped callback is run in the context in which it was
# wrapped, not when it was added to the IOLoop.
def f1():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)
self.wait()
def test_pre_wrap_with_args(self):
# Same as test_pre_wrap, but the function takes arguments.
# Implementation note: The function must not be wrapped in a
# functools.partial until after it has been passed through
# stack_context.wrap
def f1(foo, bar):
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop((foo, bar))
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)
result = self.wait()
self.assertEqual(result, (1, 2))
class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
# Repeat the add_callback tests using add_callback_from_signal
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback_from_signal(callback, *args, **kwargs)
@unittest.skipIf(futures is None, "futures module not present")
class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
self.io_loop.add_future(pool.submit(lambda: None),
lambda future: self.stop(future))
future = self.wait()
self.assertTrue(future.done())
self.assertTrue(future.result() is None)
def test_add_future_stack_context(self):
ready = threading.Event()
def task():
# we must wait for the ioloop callback to be scheduled before
# the task completes to ensure that add_future adds the callback
# asynchronously (which is the scenario in which capturing
# the stack_context matters)
ready.wait(1)
assert ready.isSet(), "timed out"
raise Exception("worker")
def callback(future):
self.future = future
raise Exception("callback")
def handle_exception(typ, value, traceback):
self.exception = value
self.stop()
return True
# stack_context propagates to the ioloop callback, but the worker
# task just has its exceptions caught and saved in the Future.
with futures.ThreadPoolExecutor(1) as pool:
with ExceptionStackContext(handle_exception):
self.io_loop.add_future(pool.submit(task), callback)
ready.set()
self.wait()
self.assertEqual(self.exception.args[0], "callback")
self.assertEqual(self.future.exception().args[0], "worker")
class TestIOLoopRunSync(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_sync_result(self):
self.assertEqual(self.io_loop.run_sync(lambda: 42), 42)
def test_sync_exception(self):
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(lambda: 1 / 0)
def test_async_result(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(42)
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
1 / 0
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
def test_current(self):
def f():
self.assertIs(IOLoop.current(), self.io_loop)
self.io_loop.run_sync(f)
def test_timeout(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,907 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import certifi
import errno
import logging
import os
import platform
import socket
import ssl
import sys
def _server_ssl_options():
return dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello")
class TestIOStreamWebMixin(object):
def _make_client_iostream(self):
raise NotImplementedError()
def get_app(self):
return Application([('/', HelloHandler)])
def test_connection_closed(self):
# When a server sends a response and then closes the connection,
# the client must be allowed to read the data before the IOStream
# closes itself. Epoll reports closed connections with a separate
# EPOLLRDHUP event delivered at the same time as the read event,
# while kqueue reports them as a second read/write event with an EOF
# flag.
response = self.fetch("/", headers={"Connection": "close"})
response.rethrow()
def test_read_until_close(self):
stream = self._make_client_iostream()
stream.connect(('localhost', self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"GET / HTTP/1.0\r\n\r\n")
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.startswith(b"HTTP/1.0 200"))
self.assertTrue(data.endswith(b"Hello"))
def test_read_zero_bytes(self):
self.stream = self._make_client_iostream()
self.stream.connect(("localhost", self.get_http_port()),
callback=self.stop)
self.wait()
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
self.assertEqual(data, b"HTTP/1.0 ")
# zero bytes
self.stream.read_bytes(0, self.stop)
data = self.wait()
self.assertEqual(data, b"")
# another normal read
self.stream.read_bytes(3, self.stop)
data = self.wait()
self.assertEqual(data, b"200")
self.stream.close()
def test_write_while_connecting(self):
stream = self._make_client_iostream()
connected = [False]
def connected_callback():
connected[0] = True
self.stop()
stream.connect(("localhost", self.get_http_port()),
callback=connected_callback)
# unlike the previous tests, try to write before the connection
# is complete.
written = [False]
def write_callback():
written[0] = True
self.stop()
stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n",
callback=write_callback)
self.assertTrue(not connected[0])
# by the time the write has flushed, the connection callback has
# also run
try:
self.wait(lambda: connected[0] and written[0])
finally:
logging.debug((connected, written))
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.endswith(b"Hello"))
stream.close()
@gen_test
def test_future_interface(self):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
("localhost", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
content_length = int(headers['Content-Length'])
body = yield stream.read_bytes(content_length)
self.assertEqual(body, b'Hello')
stream.close()
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
stream.close()
@gen_test
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
self.assertEqual(body, b"Hello")
# Nothing else to read; the error comes immediately without waiting
# for yield.
with self.assertRaises(StreamClosedError):
stream.read_bytes(1)
class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
raise NotImplementedError()
def _make_client_iostream(self, connection, **kwargs):
raise NotImplementedError()
def make_iostream_pair(self, **kwargs):
listener, port = bind_unused_port()
streams = [None, None]
def accept_callback(connection, address):
streams[0] = self._make_server_iostream(connection, **kwargs)
self.stop()
def connect_callback():
streams[1] = client_stream
self.stop()
netutil.add_accept_handler(listener, accept_callback,
io_loop=self.io_loop)
client_stream = self._make_client_iostream(socket.socket(), **kwargs)
client_stream.connect(('127.0.0.1', port),
callback=connect_callback)
self.wait(condition=lambda: all(streams))
self.io_loop.remove_handler(listener.fileno())
listener.close()
return streams
def test_streaming_callback_with_data_in_buffer(self):
server, client = self.make_iostream_pair()
client.write(b"abcd\r\nefgh")
server.read_until(b"\r\n", self.stop)
data = self.wait()
self.assertEqual(data, b"abcd\r\n")
def closed_callback(chunk):
self.fail()
server.read_until_close(callback=closed_callback,
streaming_callback=self.stop)
# self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
data = self.wait()
self.assertEqual(data, b"efgh")
server.close()
client.close()
def test_write_zero_bytes(self):
# Attempting to write zero bytes should run the callback without
# going into an infinite loop.
server, client = self.make_iostream_pair()
server.write(b'', callback=self.stop)
self.wait()
server.close()
client.close()
def test_connection_refused(self):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
# epoll IOLoop in this respect)
server_socket, port = bind_unused_port()
server_socket.close()
stream = IOStream(socket.socket(), self.io_loop)
self.connect_called = False
def connect_callback():
self.connect_called = True
stream.set_close_callback(self.stop)
# log messages vary by platform and ioloop implementation
with ExpectLog(gen_log, ".*", required=False):
stream.connect(("localhost", port), connect_callback)
self.wait()
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
if sys.platform != 'cygwin':
_ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
if hasattr(errno, "WSAECONNREFUSED"):
_ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
# cygwin's errnos don't match those used on native windows python
self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
def test_gaierror(self):
# Test that IOStream sets its exc_info on getaddrinfo error
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = IOStream(s, io_loop=self.io_loop)
stream.set_close_callback(self.stop)
# To reliably generate a gaierror we use a malformed domain name
# instead of a name that's simply unlikely to exist (since
# opendns and some ISPs return bogus addresses for nonexistent
# domains instead of the proper error codes).
with ExpectLog(gen_log, "Connect error"):
stream.connect(('an invalid domain', 54321))
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
def test_read_callback_error(self):
# Test that IOStream sets its exc_info when a read callback throws
server, client = self.make_iostream_pair()
try:
server.set_close_callback(self.stop)
with ExpectLog(
app_log, "(Uncaught exception|Exception in callback)"
):
# Clear ExceptionStackContext so IOStream catches error
with NullContext():
server.read_bytes(1, callback=lambda data: 1 / 0)
client.write(b"1")
self.wait()
self.assertTrue(isinstance(server.error, ZeroDivisionError))
finally:
server.close()
client.close()
def test_streaming_callback(self):
server, client = self.make_iostream_pair()
try:
chunks = []
final_called = []
def streaming_callback(data):
chunks.append(data)
self.stop()
def final_callback(data):
self.assertFalse(data)
final_called.append(True)
self.stop()
server.read_bytes(6, callback=final_callback,
streaming_callback=streaming_callback)
client.write(b"1234")
self.wait(condition=lambda: chunks)
client.write(b"5678")
self.wait(condition=lambda: final_called)
self.assertEqual(chunks, [b"1234", b"56"])
# the rest of the last chunk is still in the buffer
server.read_bytes(2, callback=self.stop)
data = self.wait()
self.assertEqual(data, b"78")
finally:
server.close()
client.close()
def test_streaming_until_close(self):
server, client = self.make_iostream_pair()
try:
chunks = []
closed = [False]
def streaming_callback(data):
chunks.append(data)
self.stop()
def close_callback(data):
assert not data, data
closed[0] = True
self.stop()
client.read_until_close(callback=close_callback,
streaming_callback=streaming_callback)
server.write(b"1234")
self.wait(condition=lambda: len(chunks) == 1)
server.write(b"5678", self.stop)
self.wait()
server.close()
self.wait(condition=lambda: closed[0])
self.assertEqual(chunks, [b"1234", b"5678"])
finally:
server.close()
client.close()
def test_delayed_close_callback(self):
# The scenario: Server closes the connection while there is a pending
# read that can be served out of buffered data. The client does not
# run the close_callback as soon as it detects the close, but rather
# defers it until after the buffered read has finished.
server, client = self.make_iostream_pair()
try:
client.set_close_callback(self.stop)
server.write(b"12")
chunks = []
def callback1(data):
chunks.append(data)
client.read_bytes(1, callback2)
server.close()
def callback2(data):
chunks.append(data)
client.read_bytes(1, callback1)
self.wait() # stopped by close_callback
self.assertEqual(chunks, [b"1", b"2"])
finally:
server.close()
client.close()
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
def f(self):
server.write(b"12")
chunks = []
chunks.append((yield client.read_bytes(1)))
server.close()
chunks.append((yield client.read_bytes(1)))
self.assertEqual(chunks, [b"1", b"2"])
try:
f(self)
finally:
server.close()
client.close()
def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
# close notifications must be delayed until all data has been
# drained into the IOStream buffer. (epoll used to use out-of-band
# close events with EPOLLRDHUP, but no longer)
#
# This depends on the read_chunk_size being smaller than the
# OS socket buffer, so make it small.
server, client = self.make_iostream_pair(read_chunk_size=256)
try:
server.write(b"A" * 512)
client.read_bytes(256, self.stop)
data = self.wait()
self.assertEqual(b"A" * 256, data)
server.close()
# Allow the close to propagate to the client side of the
# connection. Using add_callback instead of add_timeout
# doesn't seem to work, even with multiple iterations
self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
self.wait()
client.read_bytes(256, self.stop)
data = self.wait()
self.assertEqual(b"A" * 256, data)
finally:
server.close()
client.close()
def test_read_until_close_after_close(self):
# Similar to test_delayed_close_callback, but read_until_close takes
# a separate code path so test it separately.
server, client = self.make_iostream_pair()
try:
server.write(b"1234")
server.close()
# Read one byte to make sure the client has received the data.
# It won't run the close callback as long as there is more buffered
# data that could satisfy a later read.
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
client.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"234")
finally:
server.close()
client.close()
def test_streaming_read_until_close_after_close(self):
# Same as the preceding test but with a streaming_callback.
# All data should go through the streaming callback,
# and the final read callback just gets an empty string.
server, client = self.make_iostream_pair()
try:
server.write(b"1234")
server.close()
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
streaming_data = []
client.read_until_close(self.stop,
streaming_callback=streaming_data.append)
data = self.wait()
self.assertEqual(b'', data)
self.assertEqual(b''.join(streaming_data), b"234")
finally:
server.close()
client.close()
def test_large_read_until(self):
# Performance test: read_until used to have a quadratic component
# so a read_until of 4MB would take 8 seconds; now it takes 0.25
# seconds.
server, client = self.make_iostream_pair()
try:
# This test fails on pypy with ssl. I think it's because
# pypy's gc defeats moves objects, breaking the
# "frozen write buffer" assumption.
if (isinstance(server, SSLIOStream) and
platform.python_implementation() == 'PyPy'):
raise unittest.SkipTest(
"pypy gc causes problems with openssl")
NUM_KB = 4096
for i in range(NUM_KB):
client.write(b"A" * 1024)
client.write(b"\r\n")
server.read_until(b"\r\n", self.stop)
data = self.wait()
self.assertEqual(len(data), NUM_KB * 1024 + 2)
finally:
server.close()
client.close()
def test_close_callback_with_pending_read(self):
# Regression test for a bug that was introduced in 2.3
# where the IOStream._close_callback would never be called
# if there were pending reads.
OK = b"OK\r\n"
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(OK)
client.read_until(b"\r\n", self.stop)
res = self.wait()
self.assertEqual(res, OK)
server.close()
client.read_until(b"\r\n", lambda x: x)
# If _close_callback (self.stop) is not called,
# an AssertionError: Async operation timed out after 5 seconds
# will be raised.
res = self.wait()
self.assertTrue(res is None)
finally:
server.close()
client.close()
@skipIfNonUnix
def test_inline_read_error(self):
# An error on an inline read is raised without logging (on the
# assumption that it will eventually be noticed or logged further
# up the stack).
#
# This test is posix-only because windows os.close() doesn't work
# on socket FDs, but we can't close the socket object normally
# because we won't get the error we want if the socket knows
# it's closed.
server, client = self.make_iostream_pair()
try:
os.close(server.socket.fileno())
with self.assertRaises(socket.error):
server.read_bytes(1, lambda data: None)
finally:
server.close()
client.close()
def test_async_read_error_logging(self):
# Socket errors on asynchronous reads should be logged (but only
# once).
server, client = self.make_iostream_pair()
server.set_close_callback(self.stop)
try:
# Start a read that will be fullfilled asynchronously.
server.read_bytes(1, lambda data: None)
client.write(b'a')
# Stub out read_from_fd to make it fail.
def fake_read_from_fd():
os.close(server.socket.fileno())
server.__class__.read_from_fd(server)
server.read_from_fd = fake_read_from_fd
# This log message is from _handle_read (not read_from_fd).
with ExpectLog(gen_log, "error on read"):
self.wait()
finally:
server.close()
client.close()
def test_future_close_callback(self):
# Regression test for interaction between the Future read interfaces
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
def close_callback():
closed[0] = True
self.stop()
server.set_close_callback(close_callback)
try:
client.write(b'a')
future = server.read_bytes(1)
self.io_loop.add_future(future, self.stop)
self.assertEqual(self.wait().result(), b'a')
self.assertFalse(closed[0])
client.close()
self.wait()
self.assertTrue(closed[0])
finally:
server.close()
client.close()
def test_read_bytes_partial(self):
server, client = self.make_iostream_pair()
try:
# Ask for more than is available with partial=True
client.read_bytes(50, self.stop, partial=True)
server.write(b"hello")
data = self.wait()
self.assertEqual(data, b"hello")
# Ask for less than what is available; num_bytes is still
# respected.
client.read_bytes(3, self.stop, partial=True)
server.write(b"world")
data = self.wait()
self.assertEqual(data, b"wor")
# Partial reads won't return an empty string, but read_bytes(0)
# will.
client.read_bytes(0, self.stop, partial=True)
data = self.wait()
self.assertEqual(data, b'')
finally:
server.close()
client.close()
def test_read_until_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until_regex(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until_regex(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_small_reads_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write(b"a" * 1024 * 100)
for i in range(100):
client.read_bytes(1024, self.stop)
data = self.wait()
self.assertEqual(data, b"a" * 1024)
finally:
server.close()
client.close()
def test_small_read_untils_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write((b"a" * 1023 + b"\n") * 100)
for i in range(100):
client.read_until(b"\n", self.stop, max_bytes=4096)
data = self.wait()
self.assertEqual(data, b"a" * 1023 + b"\n")
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
return IOStream(socket.socket(), io_loop=self.io_loop)
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
return SSLIOStream(socket.socket(), io_loop=self.io_loop)
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
return IOStream(connection, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return IOStream(connection, **kwargs)
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
connection = ssl.wrap_socket(connection,
server_side=True,
do_handshake_on_connect=False,
**_server_ssl_options())
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
# This will run some tests that are basically redundant but it's the
# simplest way to make sure that it works to pass an SSLContext
# instead of an ssl_options dict to the SSLIOStream constructor.
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.load_cert_chain(
os.path.join(os.path.dirname(__file__), 'test.crt'),
os.path.join(os.path.dirname(__file__), 'test.key'))
connection = ssl_wrap_socket(connection, context,
server_side=True,
do_handshake_on_connect=False)
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
return SSLIOStream(connection, io_loop=self.io_loop,
ssl_options=context, **kwargs)
class TestIOStreamStartTLS(AsyncTestCase):
def setUp(self):
try:
super(TestIOStreamStartTLS, self).setUp()
self.listener, self.port = bind_unused_port()
self.server_stream = None
self.server_accepted = Future()
netutil.add_accept_handler(self.listener, self.accept)
self.client_stream = IOStream(socket.socket())
self.io_loop.add_future(self.client_stream.connect(
('127.0.0.1', self.port)), self.stop)
self.wait()
self.io_loop.add_future(self.server_accepted, self.stop)
self.wait()
except Exception as e:
print(e)
raise
def tearDown(self):
if self.server_stream is not None:
self.server_stream.close()
if self.client_stream is not None:
self.client_stream.close()
self.listener.close()
super(TestIOStreamStartTLS, self).tearDown()
def accept(self, connection, address):
if self.server_stream is not None:
self.fail("should only get one connection")
self.server_stream = IOStream(connection)
self.server_accepted.set_result(None)
@gen.coroutine
def client_send_line(self, line):
self.client_stream.write(line)
recv_line = yield self.server_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
@gen.coroutine
def server_send_line(self, line):
self.server_stream.write(line)
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
def client_start_tls(self, ssl_options=None):
client_stream = self.client_stream
self.client_stream = None
return client_stream.start_tls(False, ssl_options)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
self.server_stream = None
return server_stream.start_tls(True, ssl_options)
@gen_test
def test_start_tls_smtp(self):
# This flow is simplified from RFC 3207 section 5.
# We don't really need all of this, but it helps to make sure
# that after realistic back-and-forth traffic the buffers end up
# in a sane state.
yield self.server_send_line(b"220 mail.example.com ready\r\n")
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250-mail.example.com welcome\r\n")
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
client_future = self.client_start_tls()
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
self.assertTrue(isinstance(self.client_stream, SSLIOStream))
self.assertTrue(isinstance(self.server_stream, SSLIOStream))
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250 mail.example.com welcome\r\n")
@gen_test
def test_handshake_fail(self):
self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):
r, w = os.pipe()
rs = PipeIOStream(r, io_loop=self.io_loop)
ws = PipeIOStream(w, io_loop=self.io_loop)
ws.write(b"hel")
ws.write(b"lo world")
rs.read_until(b' ', callback=self.stop)
data = self.wait()
self.assertEqual(data, b"hello ")
rs.read_bytes(3, self.stop)
data = self.wait()
self.assertEqual(data, b"wor")
ws.close()
rs.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"ld")
rs.close()
def test_pipe_iostream_big_write(self):
r, w = os.pipe()
rs = PipeIOStream(r, io_loop=self.io_loop)
ws = PipeIOStream(w, io_loop=self.io_loop)
NUM_BYTES = 1048576
# Write 1MB of data, which should fill the buffer
ws.write(b"1" * NUM_BYTES)
rs.read_bytes(NUM_BYTES, self.stop)
data = self.wait()
self.assertEqual(data, b"1" * NUM_BYTES)
ws.close()
rs.close()

View file

@ -0,0 +1,59 @@
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import os
import tornado.locale
from tornado.escape import utf8
from tornado.test.util import unittest
from tornado.util import u, unicode_type
class TranslationLoaderTest(unittest.TestCase):
# TODO: less hacky way to get isolated tests
SAVE_VARS = ['_translations', '_supported_locales', '_use_gettext']
def clear_locale_cache(self):
if hasattr(tornado.locale.Locale, '_cache'):
del tornado.locale.Locale._cache
def setUp(self):
self.saved = {}
for var in TranslationLoaderTest.SAVE_VARS:
self.saved[var] = getattr(tornado.locale, var)
self.clear_locale_cache()
def tearDown(self):
for k, v in self.saved.items():
setattr(tornado.locale, k, v)
self.clear_locale_cache()
def test_csv(self):
tornado.locale.load_translations(
os.path.join(os.path.dirname(__file__), 'csv_translations'))
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
def test_gettext(self):
tornado.locale.load_gettext_translations(
os.path.join(os.path.dirname(__file__), 'gettext_translations'),
"tornado_test")
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
self.assertTrue(isinstance(name, unicode_type))
self.assertEqual(name, u('Espa\u00f1ol'))
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
class EnglishTest(unittest.TestCase):
def test_format_date(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')

View file

@ -0,0 +1,207 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import warnings
from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import u, bytes_type, basestring_type
@contextlib.contextmanager
def ignore_bytes_warning():
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=BytesWarning)
yield
class LogFormatterTest(unittest.TestCase):
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
def setUp(self):
self.formatter = LogFormatter(color=False)
# Fake color support. We can't guarantee anything about the $TERM
# variable when the tests are run, so just patch in some values
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
self.formatter._colors = {
logging.ERROR: u("\u0001"),
}
self.formatter._normal = u("\u0002")
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger('LogFormatterTest')
self.logger.propagate = False
self.tempdir = tempfile.mkdtemp()
self.filename = os.path.join(self.tempdir, 'log.out')
self.handler = self.make_handler(self.filename)
self.handler.setFormatter(self.formatter)
self.logger.addHandler(self.handler)
def tearDown(self):
self.handler.close()
os.unlink(self.filename)
os.rmdir(self.tempdir)
def make_handler(self, filename):
# Base case: default setup without explicit encoding.
# In python 2, supports arbitrary byte strings and unicode objects
# that contain only ascii. In python 3, supports ascii-only unicode
# strings (but byte strings will be repr'd automatically).
return logging.FileHandler(filename)
def get_output(self):
with open(self.filename, "rb") as f:
line = f.read().strip()
m = LogFormatterTest.LINE_RE.match(line)
if m:
return m.group(1)
else:
raise Exception("output didn't match regex: %r" % line)
def test_basic_logging(self):
self.logger.error("foo")
self.assertEqual(self.get_output(), b"foo")
def test_bytes_logging(self):
with ignore_bytes_warning():
# This will be "\xe9" on python 2 or "b'\xe9'" on python 3
self.logger.error(b"\xe9")
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes_type, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
else:
# on python 3, byte strings always get repr'd even if
# they're ascii-only, so this degenerates into another
# copy of test_bytes_logging.
self.assertEqual(self.get_output(), utf8(repr(utf8(u("\u00e9")))))
def test_bytes_exception_logging(self):
try:
raise Exception(b'\xe9')
except Exception:
self.logger.exception('caught exception')
# This will be "Exception: \xe9" on python 2 or
# "Exception: b'\xe9'" on python 3.
output = self.get_output()
self.assertRegexpMatches(output, br'Exception.*\\xe9')
# The traceback contains newlines, which should not have been escaped.
self.assertNotIn(br'\n', output)
class UnicodeLogFormatterTest(LogFormatterTest):
def make_handler(self, filename):
# Adding an explicit encoding configuration allows non-ascii unicode
# strings in both python 2 and 3, without changing the behavior
# for byte strings.
return logging.FileHandler(filename, encoding="utf8")
def test_unicode_logging(self):
self.logger.error(u("\u00e9"))
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
class EnablePrettyLoggingTest(unittest.TestCase):
def setUp(self):
super(EnablePrettyLoggingTest, self).setUp()
self.options = OptionParser()
define_logging_options(self.options)
self.logger = logging.Logger('tornado.test.log_test.EnablePrettyLoggingTest')
self.logger.propagate = False
def test_log_file(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + '/test_log'
enable_pretty_logging(options=self.options, logger=self.logger)
self.assertEqual(1, len(self.logger.handlers))
self.logger.error('hello')
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + '/test_log*')
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(f.read(), r'^\[E [^]]*\] hello$')
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then logs
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
IMPORT = 'from tornado.options import options, parse_command_line'
LOG_INFO = 'import logging; logging.info("hello")'
program = ';'.join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
[sys.executable, '-c', program] + (args or []),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
return b'hello' in stdout
def test_default(self):
self.assertFalse(self.logs_present('pass'))
def test_tornado_default(self):
self.assertTrue(self.logs_present('parse_command_line()'))
def test_disable_command_line(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=none']))
def test_disable_command_line_case_insensitive(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=None']))
def test_disable_code_string(self):
self.assertFalse(self.logs_present(
'options.logging = "none"; parse_command_line()'))
def test_disable_code_none(self):
self.assertFalse(self.logs_present(
'options.logging = None; parse_command_line()'))
def test_disable_override(self):
# command line trumps code defaults
self.assertTrue(self.logs_present(
'options.logging = None; parse_command_line()',
['--logging=info']))

View file

@ -0,0 +1,168 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import signal
import socket
from subprocess import Popen
import sys
import time
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNoNetwork
try:
from concurrent import futures
except ImportError:
futures = None
try:
import pycares
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted
import twisted.names
except ImportError:
twisted = None
else:
from tornado.platform.twisted import TwistedResolver
class _ResolverTestMixin(object):
def skipOnCares(self):
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
if self.resolver.__class__.__name__ == 'CaresResolver':
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
@gen_test
def test_future_interface(self):
addrinfo = yield self.resolver.resolve('localhost', 80,
socket.AF_UNSPEC)
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
def test_bad_host(self):
self.skipOnCares()
def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val)
return True # Halt propagation.
with ExceptionStackContext(handler):
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
result = self.wait()
self.assertIsInstance(result, Exception)
@gen_test
def test_future_interface_bad_host(self):
self.skipOnCares()
with self.assertRaises(Exception):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(ThreadedResolverTest, self).setUp()
self.resolver = ThreadedResolver(io_loop=self.io_loop)
def tearDown(self):
self.resolver.close()
super(ThreadedResolverTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
# Test for a deadlock when importing a module that runs the
# ThreadedResolver at import-time. See resolve_test.py for
# full explanation.
command = [
sys.executable,
'-c',
'import tornado.test.resolve_test_helper']
start = time.time()
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
while time.time() - start < TIMEOUT:
return_code = popen.poll()
if return_code is not None:
self.assertEqual(0, return_code)
return # Success.
time.sleep(0.05)
self.fail("import timed out")
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(CaresResolverTest, self).setUp()
self.resolver = CaresResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop)
class IsValidIPTest(unittest.TestCase):
def test_is_valid_ip(self):
self.assertTrue(is_valid_ip('127.0.0.1'))
self.assertTrue(is_valid_ip('4.4.4.4'))
self.assertTrue(is_valid_ip('::1'))
self.assertTrue(is_valid_ip('2620:0:1cfe:face:b00c::3'))
self.assertTrue(not is_valid_ip('www.google.com'))
self.assertTrue(not is_valid_ip('localhost'))
self.assertTrue(not is_valid_ip('4.4.4.4<'))
self.assertTrue(not is_valid_ip(' 127.0.0.1'))
self.assertTrue(not is_valid_ip(''))
self.assertTrue(not is_valid_ip(' '))
self.assertTrue(not is_valid_ip('\n'))
self.assertTrue(not is_valid_ip('\x00'))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
if 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
sockets = bind_sockets(None, 'localhost')
try:
port = sockets[0].getsockname()[1]
self.assertTrue(all(s.getsockname()[1] == port
for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()

View file

@ -0,0 +1,2 @@
port=443
port=443

View file

@ -0,0 +1,220 @@
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import os
import sys
from tornado.options import OptionParser, Error
from tornado.util import basestring_type
from tornado.test.util import unittest
try:
from cStringIO import StringIO # python 2
except ImportError:
from io import StringIO # python 3
try:
from unittest import mock # python 3.3
except ImportError:
try:
import mock # third-party mock package
except ImportError:
mock = None
class OptionsTest(unittest.TestCase):
def test_parse_command_line(self):
options = OptionParser()
options.define("port", default=80)
options.parse_command_line(["main.py", "--port=443"])
self.assertEqual(options.port, 443)
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEquals(options.port, 443)
def test_parse_callbacks(self):
options = OptionParser()
self.called = False
def callback():
self.called = True
options.add_parse_callback(callback)
# non-final parse doesn't run callbacks
options.parse_command_line(["main.py"], final=False)
self.assertFalse(self.called)
# final parse does
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
# callbacks can be run more than once on the same options
# object if there are multiple final parses
self.called = False
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
def test_help(self):
options = OptionParser()
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(SystemExit):
options.parse_command_line(["main.py", "--help"])
usage = sys.stderr.getvalue()
finally:
sys.stderr = orig_stderr
self.assertIn("Usage:", usage)
def test_subcommand(self):
base_options = OptionParser()
base_options.define("verbose", default=False)
sub_options = OptionParser()
sub_options.define("foo", type=str)
rest = base_options.parse_command_line(
["main.py", "--verbose", "subcommand", "--foo=bar"])
self.assertEqual(rest, ["subcommand", "--foo=bar"])
self.assertTrue(base_options.verbose)
rest2 = sub_options.parse_command_line(rest)
self.assertEqual(rest2, [])
self.assertEqual(sub_options.foo, "bar")
# the two option sets are distinct
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(Error):
sub_options.parse_command_line(["subcommand", "--verbose"])
finally:
sys.stderr = orig_stderr
def test_setattr(self):
options = OptionParser()
options.define('foo', default=1, type=int)
options.foo = 2
self.assertEqual(options.foo, 2)
def test_setattr_type_check(self):
# setattr requires that options be the right type and doesn't
# parse from string formats.
options = OptionParser()
options.define('foo', default=1, type=int)
with self.assertRaises(Error):
options.foo = '2'
def test_setattr_with_callback(self):
values = []
options = OptionParser()
options.define('foo', default=1, type=int, callback=values.append)
options.foo = 2
self.assertEqual(values, [2])
def _sample_options(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', default=2)
return options
def test_iter(self):
options = self._sample_options()
# OptionParsers always define 'help'.
self.assertEqual(set(['a', 'b', 'help']), set(iter(options)))
def test_getitem(self):
options = self._sample_options()
self.assertEqual(1, options['a'])
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
expected = [('a', 1), ('b', 2), ('help', options.help)]
actual = sorted(options.items())
self.assertEqual(expected, actual)
def test_as_dict(self):
options = self._sample_options()
expected = {'a': 1, 'b': 2, 'help': options.help}
self.assertEqual(expected, options.as_dict())
def test_group_dict(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', group='b_group', default=2)
frame = sys._getframe(0)
this_file = frame.f_code.co_filename
self.assertEqual(set(['b_group', '', this_file]), options.groups())
b_group_dict = options.group_dict('b_group')
self.assertEqual({'b': 2}, b_group_dict)
self.assertEqual({}, options.group_dict('nonexistent'))
@unittest.skipIf(mock is None, 'mock package not present')
def test_mock_patch(self):
# ensure that our setattr hooks don't interfere with mock.patch
options = OptionParser()
options.define('foo', default=1)
options.parse_command_line(['main.py', '--foo=2'])
self.assertEqual(options.foo, 2)
with mock.patch.object(options.mockable(), 'foo', 3):
self.assertEqual(options.foo, 3)
self.assertEqual(options.foo, 2)
# Try nested patches mixed with explicit sets
with mock.patch.object(options.mockable(), 'foo', 4):
self.assertEqual(options.foo, 4)
options.foo = 5
self.assertEqual(options.foo, 5)
with mock.patch.object(options.mockable(), 'foo', 6):
self.assertEqual(options.foo, 6)
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def test_types(self):
options = OptionParser()
options.define('str', type=str)
options.define('basestring', type=basestring_type)
options.define('int', type=int)
options.define('float', type=float)
options.define('datetime', type=datetime.datetime)
options.define('timedelta', type=datetime.timedelta)
options.parse_command_line(['main.py',
'--str=asdf',
'--basestring=qwer',
'--int=42',
'--float=1.5',
'--datetime=2013-04-28 05:16',
'--timedelta=45s'])
self.assertEqual(options.str, 'asdf')
self.assertEqual(options.basestring, 'qwer')
self.assertEqual(options.int, 42)
self.assertEqual(options.float, 1.5)
self.assertEqual(options.datetime,
datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
def test_multiple_string(self):
options = OptionParser()
options.define('foo', type=str, multiple=True)
options.parse_command_line(['main.py', '--foo=a,b,c'])
self.assertEqual(options.foo, ['a', 'b', 'c'])
def test_multiple_int(self):
options = OptionParser()
options.define('foo', type=int, multiple=True)
options.parse_command_line(['main.py', '--foo=1,3,5:7'])
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
def test_error_redefine(self):
options = OptionParser()
options.define('foo')
with self.assertRaises(Error) as cm:
options.define('foo')
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.*already defined')

View file

@ -0,0 +1,214 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import logging
import os
import signal
import subprocess
import sys
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
def skip_if_twisted():
if IOLoop.configured_class().__name__.endswith(('TwistedIOLoop',
'AsyncIOMainLoop')):
raise unittest.SkipTest("Process tests not compatible with "
"TwistedIOLoop or AsyncIOMainLoop")
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@skipIfNonUnix
class ProcessTest(unittest.TestCase):
def get_app(self):
class ProcessHandler(RequestHandler):
def get(self):
if self.get_argument("exit", None):
# must use os._exit instead of sys.exit so unittest's
# exception handler doesn't catch it
os._exit(int(self.get_argument("exit")))
if self.get_argument("signal", None):
os.kill(os.getpid(),
int(self.get_argument("signal")))
self.write(str(os.getpid()))
return Application([("/", ProcessHandler)])
def tearDown(self):
if task_id() is not None:
# We're in a child process, and probably got to this point
# via an uncaught exception. If we return now, both
# processes will continue with the rest of the test suite.
# Exit now so the parent process will restart the child
# (since we don't have a clean way to signal failure to
# the parent that won't restart)
logging.error("aborting child process from tearDown")
logging.shutdown()
os._exit(1)
# In the surviving process, clear the alarm we set earlier
signal.alarm(0)
super(ProcessTest, self).tearDown()
def test_multi_process(self):
# This test can't work on twisted because we use the global reactor
# and have no way to get it back into a sane state after the fork.
skip_if_twisted()
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
self.assertFalse(IOLoop.initialized())
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
id = fork_processes(3, max_restarts=3)
self.assertTrue(id is not None)
signal.alarm(5) # child processes
except SystemExit as e:
# if we exit cleanly from fork_processes, all the child processes
# finished with status 0
self.assertEqual(e.code, 0)
self.assertTrue(task_id() is None)
sock.close()
return
try:
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.instance().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
# Always use SimpleAsyncHTTPClient here; the curl
# version appears to get confused sometimes if the
# connection gets closed before it's had a chance to
# switch from writing mode to reading mode.
client = HTTPClient(SimpleAsyncHTTPClient)
def fetch(url, fail_ok=False):
try:
return client.fetch(get_url(url))
except HTTPError as e:
if not (fail_ok and e.code == 599):
raise
# Make two processes exit abnormally
fetch("/?exit=2", fail_ok=True)
fetch("/?exit=3", fail_ok=True)
# They've been restarted, so a new fetch will work
int(fetch("/").body)
# Now the same with signals
# Disabled because on the mac a process dying with a signal
# can trigger an "Application exited abnormally; send error
# report to Apple?" prompt.
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
# int(fetch("/").body)
# Now kill them normally so they won't be restarted
fetch("/?exit=0", fail_ok=True)
# One process left; watch it's pid change
pid = int(fetch("/").body)
fetch("/?exit=4", fail_ok=True)
pid2 = int(fetch("/").body)
self.assertNotEqual(pid, pid2)
# Kill the last one so we shut down cleanly
fetch("/?exit=0", fail_ok=True)
os._exit(0)
except Exception:
logging.error("exception in child process %d", id, exc_info=True)
raise
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
def test_subprocess(self):
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
# This test fails non-deterministically with LayeredTwistedIOLoop.
# (the read_until('\n') returns '\n' instead of 'hello\n')
# This probably indicates a problem with either TornadoReactor
# or TwistedIOLoop, but I haven't been able to track it down
# and for now this is just causing spurious travis-ci failures.
raise unittest.SkipTest("Subprocess tests not compatible with "
"LayeredTwistedIOLoop")
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
subproc.stdin.write(b"print('hello')\n")
subproc.stdout.read_until(b'\n', self.stop)
data = self.wait()
self.assertEqual(data, b"hello\n")
subproc.stdout.read_until(b">>> ", self.stop)
self.wait()
subproc.stdin.write(b"raise SystemExit\n")
subproc.stdout.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"")
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
subproc.stdin.close()
subproc.stdout.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"\n")
def test_stderr(self):
subproc = Subprocess([sys.executable, '-u', '-c',
r"import sys; sys.stderr.write('hello\n')"],
stderr=Subprocess.STREAM,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stderr.read_until(b'\n', self.stop)
data = self.wait()
self.assertEqual(data, b'hello\n')
def test_sigchild(self):
# Twisted's SIGCHLD handler and Subprocess's conflict with each other.
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'],
io_loop=self.io_loop)
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c',
'import time; time.sleep(30)'],
io_loop=self.io_loop)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
ret = self.wait()
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)

View file

@ -0,0 +1,12 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver
from tornado.util import u
# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
# but blocks on the import lock. Verify that ThreadedResolver avoids
# this deadlock.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u('localhost'), 80))

View file

@ -0,0 +1,137 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import gc
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
from tornado.test.util import unittest
try:
reduce # py2
except NameError:
from functools import reduce # py3
TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.auth_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
'tornado.test.gen_test',
'tornado.test.httpclient_test',
'tornado.test.httpserver_test',
'tornado.test.httputil_test',
'tornado.test.import_test',
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
'tornado.test.locale_test',
'tornado.test.netutil_test',
'tornado.test.log_test',
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
'tornado.test.util_test',
'tornado.test.web_test',
'tornado.test.websocket_test',
'tornado.test.wsgi_test',
]
def all():
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
class TornadoTextTestRunner(unittest.TextTestRunner):
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(textwrap.fill(
"Some tests were skipped because: %s" %
", ".join(sorted(skip_reasons))))
self.stream.write("\n")
return result
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
import warnings
# Be strict about most warnings. This also turns on warnings that are
# ignored by default, including DeprecationWarnings and
# python 3.2's ResourceWarnings.
warnings.filterwarnings("error")
# setuptools sometimes gives ImportWarnings about things that are on
# sys.path even if they're not being used.
warnings.filterwarnings("ignore", category=ImportWarning)
# Tornado generally shouldn't use anything deprecated, but some of
# our dependencies do (last match wins).
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning,
module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings("error", category=PendingDeprecationWarning,
module=r"tornado\..*")
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
callback=Resolver.configure)
define('debug_gc', type=str, multiple=True,
help="A comma-separated list of gc module debug constants, "
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
callback=lambda values: gc.set_debug(
reduce(operator.or_, (getattr(gc, v) for v in values))))
define('locale', type=str, default=None,
callback=lambda x: locale.setlocale(locale.LC_ALL, x))
def configure_ioloop():
kwargs = {}
if options.ioloop_time_monotonic:
from tornado.platform.auto import monotonic_time
if monotonic_time is None:
raise RuntimeError("monotonic clock not found")
kwargs['time_func'] = monotonic_time
if options.ioloop or kwargs:
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
# HACK: unittest.main will make its own changes to the warning
# configuration, which may conflict with the settings above
# or command-line flags like -bb. Passing warnings=False
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
tornado.testing.main(**kwargs)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,552 @@
from __future__ import absolute_import, division, print_function, with_statement
import collections
from contextlib import closing
import errno
import gzip
import logging
import os
import re
import socket
import sys
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
force_instance=True)
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
return client
class TriggerHandler(RequestHandler):
def initialize(self, queue, wake_callback):
self.queue = queue
self.wake_callback = wake_callback
@asynchronous
def get(self):
logging.debug("queuing trigger")
self.queue.append(self.finish)
if self.get_argument("wake", "true") == "true":
self.wake_callback()
class HangHandler(RequestHandler):
@asynchronous
def get(self):
pass
class ContentLengthHandler(RequestHandler):
def get(self):
self.set_header("Content-Length", self.get_argument("value"))
self.write("ok")
class HeadHandler(RequestHandler):
def head(self):
self.set_header("Content-Length", "7")
class OptionsHandler(RequestHandler):
def options(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.write("ok")
class NoContentHandler(RequestHandler):
def get(self):
if self.get_argument("error", None):
self.set_header("Content-Length", "5")
self.write("hello")
self.set_status(204)
class SeeOtherPostHandler(RequestHandler):
def post(self):
redirect_code = int(self.request.body)
assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
self.set_header("Location", "/see_other_get")
self.set_status(redirect_code)
class SeeOtherGetHandler(RequestHandler):
def get(self):
if self.request.body:
raise Exception("unexpected body %r" % self.request.body)
self.write("ok")
class HostEchoHandler(RequestHandler):
def get(self):
self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
@gen.coroutine
def get(self):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.stream
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
def prepare(self):
self.set_status(403)
self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque()
return Application([
url("/trigger", TriggerHandler, dict(queue=self.triggers,
wake_callback=self.stop)),
url("/chunk", ChunkHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/hang", HangHandler),
url("/hello", HelloWorldHandler),
url("/content_length", ContentLengthHandler),
url("/head", HeadHandler),
url("/options", OptionsHandler),
url("/no_content", NoContentHandler),
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
], gzip=True)
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
SimpleAsyncHTTPClient(self.io_loop))
# unless force_instance is used
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(self.io_loop,
force_instance=True))
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(io_loop2))
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
self.assertEqual(client.max_clients, 2)
seen = []
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
client.fetch(self.get_url("/trigger"),
lambda response, i=i: (seen.append(i), self.stop()))
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
# Finish the first two requests and let the next two through
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: (len(self.triggers) == 2 and
len(seen) == 2))
self.assertEqual(set(seen), set([0, 1]))
self.assertEqual(len(client.queue), 0)
# Finish all the pending requests
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: len(seen) == 4)
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/countdown/3'), self.stop,
max_redirects=3)
response = self.wait()
response.rethrow()
def test_default_certificates_exist(self):
open(_default_ca_certs()).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
response = self.fetch("/chunk", use_gzip=False,
headers={"Accept-Encoding": "gzip"})
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
def test_max_redirects(self):
response = self.fetch("/countdown/5", max_redirects=3)
self.assertEqual(302, response.code)
# We requested 5, followed three redirects for 4, 3, 2, then the last
# unfollowed redirect is to 1.
self.assertTrue(response.request.url.endswith("/countdown/5"))
self.assertTrue(response.effective_url.endswith("/countdown/2"))
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
def test_header_reuse(self):
# Apps may reuse a headers object if they are only passing in constant
# headers like user-agent. The header object should not be modified.
headers = HTTPHeaders({'User-Agent': 'Foo'})
self.fetch("/hello", headers=headers)
self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])
def test_see_other_redirect(self):
for code in (302, 303):
response = self.fetch("/see_other_post", method="POST", body="%d" % code)
self.assertEqual(200, response.code)
self.assertTrue(response.request.url.endswith("/see_other_post"))
self.assertTrue(response.effective_url.endswith("/see_other_get"))
# request is the original request, is a POST still
self.assertEqual("POST", response.request.method)
@skipOnTravis
def test_request_timeout(self):
response = self.fetch('/trigger?wake=false', request_timeout=0.1)
self.assertEqual(response.code, 599)
self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@skipIfNoIPv6
def test_ipv6(self):
try:
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
except socket.gaierror as e:
if e.args[0] == socket.EAI_ADDRFAMILY:
# python supports ipv6, but it's not configured on the network
# interface, so skip this test.
return
raise
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
# ipv6 is currently enabled by default but can be disabled
self.http_client.fetch(url, self.stop, allow_ipv6=False)
response = self.wait()
self.assertEqual(response.code, 599)
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(response.body, b"Hello world!")
def xtest_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,4")
self.assertEqual(response.code, 599)
response = self.fetch("/content_length?value=2,%202,3")
self.assertEqual(response.code, 599)
def test_head_request(self):
response = self.fetch("/head", method="HEAD")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "7")
self.assertFalse(response.body)
def test_options_request(self):
response = self.fetch("/options", method="OPTIONS")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "2")
self.assertEqual(response.headers["access-control-allow-origin"], "*")
self.assertEqual(response.body, b"ok")
def test_no_content(self):
response = self.fetch("/no_content")
self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway.
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
def test_host_header(self):
host_re = re.compile(b"^localhost:[0-9]+$")
response = self.fetch("/host_echo")
self.assertTrue(host_re.match(response.body))
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
server_socket, port = bind_unused_port()
server_socket.close()
with ExpectLog(gen_log, ".*", required=False):
self.http_client.fetch("http://localhost:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(response.error)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
self.assertTrue(contains_errno, response.error)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(response.error),
response.error)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/trigger'), self.stop,
request_timeout=10)
# Wait for the trigger request to block, not complete.
self.wait()
client.fetch(self.get_url('/hello'), self.stop,
connect_timeout=0.1)
response = self.wait()
self.assertEqual(response.code, 599)
self.assertTrue(response.request_time < 1, response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
self.triggers.popleft()()
self.wait()
def test_no_content_length(self):
response = self.fetch("/no_content_length")
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
write(b'5678')
@gen.coroutine
def async_body_producer(self, write):
yield write(b'1234')
yield gen.Task(IOLoop.current().add_callback)
yield write(b'5678')
def test_sync_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch("/echo_post", method="POST",
body=b"1234",
expect_100_continue=True)
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
response = self.fetch("/respond_in_prepare", method="POST",
body_producer=body_producer,
expect_100_continue=True)
self.assertEqual(response.code, 403)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
super(SimpleHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
**kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
def setUp(self):
super(SimpleHTTPSClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
defaults=dict(validate_cert=False),
**kwargs)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
super(CreateAsyncHTTPClientTestCase, self).setUp()
self.saved = AsyncHTTPClient._save_configuration()
def tearDown(self):
AsyncHTTPClient._restore_configuration(self.saved)
super(CreateAsyncHTTPClientTestCase, self).tearDown()
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
self.assertEqual(client.max_clients, 10)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11)
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
self.assertEqual(client.max_clients, 12)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14)
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
self.respond_200)
def respond_200(self):
self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA",
self.request.connection.stream.close)
def get_app(self):
# Not a full Application, but works as an HTTPServer callback
return self.respond_100
def test_100_continue(self):
res = self.fetch('/')
self.assertEqual(res.body, b'A')
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
hostname_mapping={
'www.example.com': '127.0.0.1',
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
})
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_hostname_mapping(self):
self.http_client.fetch(
'http://www.example.com:%d/hello' % self.get_http_port(), self.stop)
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
def test_port_mapping(self):
self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def setUp(self):
# Dummy Resolver subclass that never invokes its callback.
class BadResolver(Resolver):
def resolve(self, *args, **kwargs):
pass
super(ResolveTimeoutTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
resolver=BadResolver())
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_resolve_timeout(self):
response = self.fetch('/hello', connect_timeout=0.1)
self.assertEqual(response.code, 599)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 100)
self.write("ok")
class LargeHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
return Application([('/small', SmallHeaders),
('/large', LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
def test_small_headers(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'ok')
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)

View file

@ -0,0 +1,288 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen
from tornado.log import app_log
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
ExceptionStackContext, run_with_stack_context, _state)
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.web import asynchronous, Application, RequestHandler
import contextlib
import functools
import logging
class TestRequestHandler(RequestHandler):
def __init__(self, app, request, io_loop):
super(TestRequestHandler, self).__init__(app, request)
self.io_loop = io_loop
@asynchronous
def get(self):
logging.debug('in get()')
# call self.part2 without a self.async_callback wrapper. Its
# exception should still get thrown
self.io_loop.add_callback(self.part2)
def part2(self):
logging.debug('in part2()')
# Go through a third layer to make sure that contexts once restored
# are again passed on to future callbacks
self.io_loop.add_callback(self.part3)
def part3(self):
logging.debug('in part3()')
raise Exception('test exception')
def write_error(self, status_code, **kwargs):
if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
self.write('got expected exception')
else:
self.write('unexpected failure')
class HTTPStackContextTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', TestRequestHandler,
dict(io_loop=self.io_loop))])
def test_stack_context(self):
with ExpectLog(app_log, "Uncaught exception GET /"):
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
self.assertEqual(self.response.code, 500)
self.assertTrue(b'got expected exception' in self.response.body)
def handle_response(self, response):
self.response = response
self.stop()
class StackContextTest(AsyncTestCase):
def setUp(self):
super(StackContextTest, self).setUp()
self.active_contexts = []
@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
# Simulates the effect of an asynchronous library that uses its own
# StackContext internally and then returns control to the application.
def test_exit_library_context(self):
def library_function(callback):
# capture the caller's context before introducing our own
callback = wrap(callback)
with StackContext(functools.partial(self.context, 'library')):
self.io_loop.add_callback(
functools.partial(library_inner_callback, callback))
def library_inner_callback(callback):
self.assertEqual(self.active_contexts[-2:],
['application', 'library'])
callback()
def final_callback():
# implementation detail: the full context stack at this point
# is ['application', 'library', 'application']. The 'library'
# context was not removed, but is no longer innermost so
# the application context takes precedence.
self.assertEqual(self.active_contexts[-1], 'application')
self.stop()
with StackContext(functools.partial(self.context, 'application')):
library_function(final_callback)
self.wait()
def test_deactivate(self):
deactivate_callbacks = []
def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)
def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)
def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)
def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)
def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_deactivate_order(self):
# Stack context deactivation has separate logic for deactivation at
# the head and tail of the stack, so make sure it works in any order.
def check_contexts():
# Make sure that the full-context array and the exception-context
# linked lists are consistent with each other.
full_contexts, chain = _state.contexts
exception_contexts = []
while chain is not None:
exception_contexts.append(chain)
chain = chain.old_contexts[1]
self.assertEqual(list(reversed(full_contexts)), exception_contexts)
return list(self.active_contexts)
def make_wrapped_function():
"""Wraps a function in three stack contexts, and returns
the function along with the deactivation functions.
"""
# Remove the test's stack context to make sure we can cover
# the case where the last context is deactivated.
with NullContext():
partial = functools.partial
with StackContext(partial(self.context, 'c0')) as c0:
with StackContext(partial(self.context, 'c1')) as c1:
with StackContext(partial(self.context, 'c2')) as c2:
return (wrap(check_contexts), [c0, c1, c2])
# First make sure the test mechanism works without any deactivations
func, deactivate_callbacks = make_wrapped_function()
self.assertEqual(func(), ['c0', 'c1', 'c2'])
# Deactivate the tail
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[0]()
self.assertEqual(func(), ['c1', 'c2'])
# Deactivate the middle
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[1]()
self.assertEqual(func(), ['c0', 'c2'])
# Deactivate the head
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[2]()
self.assertEqual(func(), ['c0', 'c1'])
def test_isolation_nonempty(self):
# f2 and f3 are a chain of operations started in context c1.
# f2 is incidentally run under context c2, but that context should
# not be passed along to f3.
def f1():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.assertIn('c1', self.active_contexts)
self.io_loop.add_callback(f3)
def f3():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_isolation_empty(self):
# Similar to test_isolation_nonempty, but here the f2/f3 chain
# is started without any context. Behavior should be equivalent
# to the nonempty case (although historically it was not)
def f1():
with NullContext():
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.io_loop.add_callback(f3)
def f3():
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_yield_in_with(self):
@gen.engine
def f():
self.callback = yield gen.Callback('a')
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Wait('a')
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
# Cleanup: to avoid GC warnings (which for some reason only seem
# to show up on py33-asyncio), invoke the callback (which will do
# nothing since the gen.Runner is already finished) and delete it.
self.callback()
del self.callback
@gen_test
def test_yield_outside_with(self):
# This pattern avoids the problem in the previous test.
cb = yield gen.Callback('k1')
with StackContext(functools.partial(self.context, 'c1')):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
def test_yield_in_with_exception_stack_context(self):
# As above, but with ExceptionStackContext instead of StackContext.
@gen.engine
def f():
with ExceptionStackContext(lambda t, v, tb: False):
yield gen.Task(self.io_loop.add_callback)
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
@gen_test
def test_yield_outside_with_exception_stack_context(self):
cb = yield gen.Callback('k1')
with ExceptionStackContext(lambda t, v, tb: False):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
@gen_test
def test_run_with_stack_context(self):
@gen.coroutine
def f1():
self.assertEqual(self.active_contexts, ['c1'])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c2')),
f2)
self.assertEqual(self.active_contexts, ['c1'])
@gen.coroutine
def f2():
self.assertEqual(self.active_contexts, ['c1', 'c2'])
yield gen.Task(self.io_loop.add_callback)
self.assertEqual(self.active_contexts, ['c1', 'c2'])
self.assertEqual(self.active_contexts, [])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
f1)
self.assertEqual(self.active_contexts, [])
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1 @@
this is the index

View file

@ -0,0 +1,2 @@
User-agent: *
Disallow: /

View file

@ -0,0 +1,278 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from contextlib import closing
import os
import socket
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
from tornado.test.util import skipIfNoIPv6, unittest
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
sockets = bind_sockets(None, 'localhost', family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
def stop(self):
super(TestTCPServer, self).stop()
for stream in self.streams:
stream.close()
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
self.server = None
self.client = TCPClient()
def start_server(self, family):
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def stop_server(self):
if self.server is not None:
self.server.stop()
self.server = None
def tearDown(self):
self.client.close()
self.stop_server()
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop)
addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host):
port = self.start_server(family)
stream = yield self.client.connect(host, port)
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
self.do_test_connect(socket.AF_INET, '127.0.0.1')
def test_connect_ipv4_dual(self):
self.do_test_connect(socket.AF_INET, 'localhost')
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_INET6, '::1')
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
self.skipTest('TwistedResolver does not support multiple addresses')
self.do_test_connect(socket.AF_INET6, 'localhost')
def test_connect_unspec_ipv4(self):
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_UNSPEC, '::1')
def test_connect_unspec_dual(self):
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(primary, [(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF2, 'b'),
(AF1, 'c'),
(AF2, 'd')])
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
class ConnectorTest(AsyncTestCase):
class FakeStream(object):
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def setUp(self):
super(ConnectorTest, self).setUp()
self.connect_futures = {}
self.streams = {}
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
(AF2, 'c'), (AF2, 'd')]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
# be closing any streams
for stream in self.streams.values():
self.assertFalse(stream.closed)
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
future = Future()
self.connect_futures[(af, addr)] = future
return future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
self.streams[addr] = ConnectorTest.FakeStream()
future.set_result(self.streams[addr])
else:
future.set_exception(IOError())
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600)
return conn, future
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
self.assertEqual(list(self.connect_futures.keys()),
[(AF1, 'a')])
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_immediate_failure(self):
# Fail with just one address.
conn, future = self.start_connect([(AF1, 'a')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_one_family_second_try_failure(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'), (AF2, 'c'))
self.resolve_connect(AF1, 'b', False)
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
# resolving 'a' after the connection has completed doesn't start 'b'
self.resolve_connect(AF1, 'a', False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
# resolving 'c' after completion closes the connection.
self.resolve_connect(AF2, 'c', True)
self.assertTrue(self.streams.pop('c').closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', False)
self.assert_pending((AF1, 'a'), (AF2, 'd'))
self.resolve_connect(AF2, 'd', False)
# one queue is now empty
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)

View file

@ -0,0 +1,412 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import sys
import traceback
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
class TemplateTest(unittest.TestCase):
def test_simple(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name="Ben"),
b"Hello Ben!")
def test_bytes(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")),
b"Hello Ben!")
def test_expressions(self):
template = Template("2 + 2 = {{ 2 + 2 }}")
self.assertEqual(template.generate(), b"2 + 2 = 4")
def test_comment(self):
template = Template("Hello{# TODO i18n #} {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")),
b"Hello Ben!")
def test_include(self):
loader = DictLoader({
"index.html": '{% include "header.html" %}\nbody text',
"header.html": "header text",
})
self.assertEqual(loader.load("index.html").generate(),
b"header text\nbody text")
def test_extends(self):
loader = DictLoader({
"base.html": """\
<title>{% block title %}default title{% end %}</title>
<body>{% block body %}default body{% end %}</body>
""",
"page.html": """\
{% extends "base.html" %}
{% block title %}page title{% end %}
{% block body %}page body{% end %}
""",
})
self.assertEqual(loader.load("page.html").generate(),
b"<title>page title</title>\n<body>page body</body>\n")
def test_relative_load(self):
loader = DictLoader({
"a/1.html": "{% include '2.html' %}",
"a/2.html": "{% include '../b/3.html' %}",
"b/3.html": "ok",
})
self.assertEqual(loader.load("a/1.html").generate(),
b"ok")
def test_escaping(self):
self.assertRaises(ParseError, lambda: Template("{{"))
self.assertRaises(ParseError, lambda: Template("{%"))
self.assertEqual(Template("{{!").generate(), b"{{")
self.assertEqual(Template("{%!").generate(), b"{%")
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
b"expr {{jquery expr}}")
def test_unicode_template(self):
template = Template(utf8(u("\u00e9")))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
def test_unicode_literal_expression(self):
# Unicode literals should be usable in templates. Note that this
# test simulates unicode characters appearing directly in the
# template file (with utf8 encoding), i.e. \u escapes would not
# be used in the template file itself.
if str is unicode_type:
# python 3 needs a different version of this test since
# 2to3 doesn't run on template internals
template = Template(utf8(u('{{ "\u00e9" }}')))
else:
template = Template(utf8(u('{{ u"\u00e9" }}')))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
def test_custom_namespace(self):
loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
self.assertEqual(loader.load("test.html").generate(), b"6")
def test_apply(self):
def upper(s):
return s.upper()
template = Template(utf8("{% apply upper %}foo{% end %}"))
self.assertEqual(template.generate(upper=upper), b"FOO")
def test_unicode_apply(self):
def upper(s):
return to_unicode(s).upper()
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
def test_bytes_apply(self):
def upper(s):
return utf8(to_unicode(s).upper())
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
def test_if(self):
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
self.assertEqual(template.generate(x=5), b"yes")
self.assertEqual(template.generate(x=3), b"no")
def test_if_empty_body(self):
template = Template(utf8("{% if True %}{% else %}{% end %}"))
self.assertEqual(template.generate(), b"")
def test_try(self):
template = Template(utf8("""{% try %}
try{% set y = 1/x %}
{% except %}-except
{% else %}-else
{% finally %}-finally
{% end %}"""))
self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
def test_comment_directive(self):
template = Template(utf8("{% comment blah blah %}foo"))
self.assertEqual(template.generate(), b"foo")
def test_break_continue(self):
template = Template(utf8("""\
{% for i in range(10) %}
{% if i == 2 %}
{% continue %}
{% end %}
{{ i }}
{% if i == 6 %}
{% break %}
{% end %}
{% end %}"""))
result = template.generate()
# remove extraneous whitespace
result = b''.join(result.split())
self.assertEqual(result, b"013456")
def test_break_outside_loop(self):
try:
Template(utf8("{% break %}"))
raise Exception("Did not get expected exception")
except ParseError:
pass
def test_break_in_apply(self):
# This test verifies current behavior, although of course it would
# be nice if apply didn't cause seemingly unrelated breakage
try:
Template(utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}"))
raise Exception("Did not get expected exception")
except ParseError:
pass
@unittest.skipIf(sys.version_info >= division.getMandatoryRelease(),
'no testable future imports')
def test_no_inherit_future(self):
# This file has from __future__ import division...
self.assertEqual(1 / 2, 0.5)
# ...but the template doesn't
template = Template('{{ 1 / 2 }}')
self.assertEqual(template.generate(), '0')
class StackTraceTest(unittest.TestCase):
def test_error_line_number_expression(self):
loader = DictLoader({"test.html": """one
two{{1/0}}
three
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_directive(self):
loader = DictLoader({"test.html": """one
two{%if 1/0%}
three{%end%}
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_module(self):
loader = DictLoader({
"base.html": "{% module Template('sub.html') %}",
"sub.html": "{{1/0}}",
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue('# base.html:1' in exc_stack)
self.assertTrue('# sub.html:1' in exc_stack)
def test_error_line_number_include(self):
loader = DictLoader({
"base.html": "{% include 'sub.html' %}",
"sub.html": "{{1/0}}",
})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:1 (via base.html:1)" in
traceback.format_exc())
def test_error_line_number_extends_base_error(self):
loader = DictLoader({
"base.html": "{{1/0}}",
"sub.html": "{% extends 'base.html' %}",
})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
def test_error_line_number_extends_sub_error(self):
loader = DictLoader({
"base.html": "{% block 'block' %}{% end %}",
"sub.html": """
{% extends 'base.html' %}
{% block 'block' %}
{{1/0}}
{% end %}
"""})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:4 (via base.html:1)" in
traceback.format_exc())
def test_multi_includes(self):
loader = DictLoader({
"a.html": "{% include 'b.html' %}",
"b.html": "{% include 'c.html' %}",
"c.html": "{{1/0}}",
})
try:
loader.load("a.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
traceback.format_exc())
class AutoEscapeTest(unittest.TestCase):
def setUp(self):
self.templates = {
"escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
"unescaped.html": "{% autoescape None %}{{ name }}",
"default.html": "{{ name }}",
"include.html": """\
escaped: {% include 'escaped.html' %}
unescaped: {% include 'unescaped.html' %}
default: {% include 'default.html' %}
""",
"escaped_block.html": """\
{% autoescape xhtml_escape %}\
{% block name %}base: {{ name }}{% end %}""",
"unescaped_block.html": """\
{% autoescape None %}\
{% block name %}base: {{ name }}{% end %}""",
# Extend a base template with different autoescape policy,
# with and without overriding the base's blocks
"escaped_extends_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}""",
"escaped_overrides_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"unescaped_extends_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}""",
"unescaped_overrides_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"raw_expression.html": """\
{% autoescape xhtml_escape %}\
expr: {{ name }}
raw: {% raw name %}""",
}
def test_default_off(self):
loader = DictLoader(self.templates, autoescape=None)
name = "Bobby <table>s"
self.assertEqual(loader.load("escaped.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("unescaped.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("default.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby <table>s\n")
def test_default_on(self):
loader = DictLoader(self.templates, autoescape="xhtml_escape")
name = "Bobby <table>s"
self.assertEqual(loader.load("escaped.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("unescaped.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("default.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby &lt;table&gt;s\n")
def test_unextended_block(self):
loader = DictLoader(self.templates)
name = "<script>"
self.assertEqual(loader.load("escaped_block.html").generate(name=name),
b"base: &lt;script&gt;")
self.assertEqual(loader.load("unescaped_block.html").generate(name=name),
b"base: <script>")
def test_extended_block(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name="<script>")
self.assertEqual(render("escaped_extends_unescaped.html"),
b"base: <script>")
self.assertEqual(render("escaped_overrides_unescaped.html"),
b"extended: &lt;script&gt;")
self.assertEqual(render("unescaped_extends_escaped.html"),
b"base: &lt;script&gt;")
self.assertEqual(render("unescaped_overrides_escaped.html"),
b"extended: <script>")
def test_raw_expression(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name='<>&"')
self.assertEqual(render("raw_expression.html"),
b"expr: &lt;&gt;&amp;&quot;\n"
b"raw: <>&\"")
def test_custom_escape(self):
loader = DictLoader({"foo.py":
"{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes_type)
return repr(native_str(s))
def render(template, name):
return loader.load(template).generate(py_escape=py_escape,
name=name)
self.assertEqual(render("foo.py", "<html>"),
b"s = '<html>'\n")
self.assertEqual(render("foo.py", "';sys.exit()"),
b"""s = "';sys.exit()"\n""")
self.assertEqual(render("foo.py", ["not a string"]),
b"""s = "['not a string']"\n""")
def test_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
loader = DictLoader({'foo.txt': """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
%}""",
})
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
b"0, 1, 2, 3, 4")
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):
self.loader = Loader(os.path.join(os.path.dirname(__file__), "templates"))
def test_utf8_in_file(self):
tmpl = self.loader.load("utf8.html")
result = tmpl.generate()
self.assertEqual(to_unicode(result).strip(), u("H\u00e9llo"))

View file

@ -0,0 +1 @@
Héllo

View file

@ -0,0 +1,15 @@
-----BEGIN CERTIFICATE-----
MIICSDCCAbGgAwIBAgIJAN1oTowzMbkzMA0GCSqGSIb3DQEBBQUAMD0xCzAJBgNV
BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRkwFwYDVQQKDBBUb3JuYWRvIFdl
YiBUZXN0MB4XDTEwMDgyNTE4MjQ0NFoXDTIwMDgyMjE4MjQ0NFowPTELMAkGA1UE
BhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoMEFRvcm5hZG8gV2Vi
IFRlc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALirW3mX4jbdFse2aZwW
zszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcLtKqv
jnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+PzU39
UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAGjUDBOMB0GA1UdDgQWBBRhJjMBYrzddCFr
/0vvPyHMeqgo0TAfBgNVHSMEGDAWgBRhJjMBYrzddCFr/0vvPyHMeqgo0TAMBgNV
HRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAGP6GaxSfb21bikcqaK3ZKCC1sRJ
tiCuvJZbBUFUCAzl05dYUfJZim/oWK+GqyUkUB8ciYivUNnn9OtS7DnlTgT2ws2e
lNgn5cuFXoAGcHXzVlHG3yoywYBf3y0Dn20uzrlLXUWJAzoSLOt2LTaXvwlgm7hF
W1q8SQ6UBshRw2X0
-----END CERTIFICATE-----

View file

@ -0,0 +1,16 @@
-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALirW3mX4jbdFse2
aZwWzszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcL
tKqvjnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+P
zU39UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAECgYEAiygNaWYrf95AcUQi9w00zpUr
nj9fNvCwxr2kVbRMvd2balS/CC4EmXPCXdVcZ3B7dBVjYzSIJV0Fh/iZLtnVysD9
fcNMZ+Cz71b/T0ItsNYOsJk0qUVyP52uqsqkNppIPJsD19C+ZeMLZj6iEiylZyl8
2U16c/kVIjER63mUEGkCQQDayQOTGPJrKHqPAkUqzeJkfvHH2yCf+cySU+w6ezyr
j9yxcq8aZoLusCebDVT+kz7RqnD5JePFvB38cMuepYBLAkEA2BTFdZx30f4moPNv
JlXlPNJMUTUzsXG7n4vNc+18O5ous0NGQII8jZWrIcTrP8wiP9fF3JwUsKrJhcBn
xRs3hQJBAIDUgz1YIE+HW3vgi1gkOh6RPdBAsVpiXtr/fggFz3j60qrO7FswaAMj
SX8c/6KUlBYkNjgP3qruFf4zcUNvEzcCQQCaioCPFVE9ByBpjLG6IUTKsz2R9xL5
nfYqrbpLZ1aq6iLsYvkjugHE4X57sHLwNfdo4dHJbnf9wqhO2MVe25BhAkBdKYpY
7OKc/2mmMbJDhVBgoixz/muN/5VjdfbvVY48naZkJF1p1tmogqPC5F1jPCS4rM+S
FfPJIHRNEn2oktw5
-----END PRIVATE KEY-----

View file

@ -0,0 +1,220 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen, ioloop
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
import contextlib
import os
import traceback
@contextlib.contextmanager
def set_environ(name, value):
old_value = os.environ.get(name)
os.environ[name] = value
try:
yield
finally:
if old_value is None:
del os.environ[name]
else:
os.environ[name] = old_value
class AsyncTestCaseTest(AsyncTestCase):
def test_exception_in_callback(self):
self.io_loop.add_callback(lambda: 1 / 0)
try:
self.wait()
self.fail("did not get expected exception")
except ZeroDivisionError:
pass
def test_wait_timeout(self):
time = self.io_loop.time
# Accept default 5-second timeout, no error
self.io_loop.add_timeout(time() + 0.01, self.stop)
self.wait()
# Timeout passed to wait()
self.io_loop.add_timeout(time() + 1, self.stop)
with self.assertRaises(self.failureException):
self.wait(timeout=0.01)
# Timeout set with environment variable
self.io_loop.add_timeout(time() + 1, self.stop)
with set_environ('ASYNC_TEST_TIMEOUT', '0.01'):
with self.assertRaises(self.failureException):
self.wait()
def test_subsequent_wait_calls(self):
"""
This test makes sure that a second call to wait()
clears the first timeout.
"""
self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
self.wait(timeout=0.02)
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15)
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):
class Test(AsyncTestCase):
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
def test_undecorated_generator_with_skip(self):
class Test(AsyncTestCase):
@unittest.skip("don't run this")
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.skipped), 1)
def test_other_return(self):
class Test(AsyncTestCase):
def test_other_return(self):
return 42
test = Test('test_other_return')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("Return value from test method ignored", result.errors[0][1])
class SetUpTearDownTest(unittest.TestCase):
def test_set_up_tear_down(self):
"""
This test makes sure that AsyncTestCase calls super methods for
setUp and tearDown.
InheritBoth is a subclass of both AsyncTestCase and
SetUpTearDown, with the ordering so that the super of
AsyncTestCase will be SetUpTearDown.
"""
events = []
result = unittest.TestResult()
class SetUpTearDown(unittest.TestCase):
def setUp(self):
events.append('setUp')
def tearDown(self):
events.append('tearDown')
class InheritBoth(AsyncTestCase, SetUpTearDown):
def test(self):
events.append('test')
InheritBoth('test').run(result)
expected = ['setUp', 'test', 'tearDown']
self.assertEqual(expected, events)
class GenTest(AsyncTestCase):
def setUp(self):
super(GenTest, self).setUp()
self.finished = False
def tearDown(self):
self.assertTrue(self.finished)
super(GenTest, self).tearDown()
@gen_test
def test_sync(self):
self.finished = True
@gen_test
def test_async(self):
yield gen.Task(self.io_loop.add_callback)
self.finished = True
def test_timeout(self):
# Set a short timeout and exceed it.
@gen_test(timeout=0.1)
def test(self):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
# This can't use assertRaises because we need to inspect the
# exc_info triple (and not just the exception object)
try:
test(self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
self.assertIn(
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
traceback.format_exc())
self.finished = True
def test_no_timeout(self):
# A test that does not exceed its timeout should succeed.
@gen_test(timeout=1)
def test(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.1)
test(self)
self.finished = True
def test_timeout_environment_variable(self):
@gen_test(timeout=0.5)
def test_long_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.25)
# Uses provided timeout of 0.5 seconds, doesn't time out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
test_long_timeout(self)
self.finished = True
def test_no_timeout_environment_variable(self):
@gen_test(timeout=0.01)
def test_short_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 1)
# Uses environment-variable timeout of 0.1, times out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
with self.assertRaises(ioloop.TimeoutError):
test_short_timeout(self)
self.finished = True
def test_with_method_args(self):
@gen_test
def test_with_args(self, *args):
self.assertEqual(args, ('test',))
yield gen.Task(self.io_loop.add_callback)
test_with_args(self, 'test')
self.finished = True
def test_with_method_kwargs(self):
@gen_test
def test_with_kwargs(self, **kwargs):
self.assertDictEqual(kwargs, {'test': 'test'})
yield gen.Task(self.io_loop.add_callback)
test_with_kwargs(self, test='test')
self.finished = True
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,618 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Unittest for the twisted-style reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
import shutil
import signal
import tempfile
import threading
try:
import fcntl
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
from tornado.platform.twisted import TornadoReactor, TwistedIOLoop
from zope.interface import implementer
have_twisted = True
except ImportError:
have_twisted = False
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
# so test for it separately.
try:
from twisted.web.client import Agent
from twisted.web.resource import Resource
from twisted.web.server import Site
have_twisted_web = True
except ImportError:
have_twisted_web = False
try:
import thread # py2
except ImportError:
import _thread as thread # py3
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.platform.select import SelectIOLoop
from tornado.testing import bind_unused_port
from tornado.test.util import unittest
from tornado.util import import_object
from tornado.web import RequestHandler, Application
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
def save_signal_handlers():
saved = {}
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGCHLD]:
saved[sig] = signal.getsignal(sig)
if "twisted" in repr(saved):
if not issubclass(IOLoop.configured_class(), TwistedIOLoop):
# when the global ioloop is twisted, we expect the signal
# handlers to be installed. Otherwise, it means we're not
# cleaning up after twisted properly.
raise Exception("twisted signal handlers already installed")
return saved
def restore_signal_handlers(saved):
for sig, handler in saved.items():
signal.signal(sig, handler)
class ReactorTestCase(unittest.TestCase):
def setUp(self):
self._saved_signals = save_signal_handlers()
self._io_loop = IOLoop()
self._reactor = TornadoReactor(self._io_loop)
def tearDown(self):
self._io_loop.close(all_fds=True)
restore_signal_handlers(self._saved_signals)
@skipIfNoTwisted
class ReactorWhenRunningTest(ReactorTestCase):
def test_whenRunning(self):
self._whenRunningCalled = False
self._anotherWhenRunningCalled = False
self._reactor.callWhenRunning(self.whenRunningCallback)
self._reactor.run()
self.assertTrue(self._whenRunningCalled)
self.assertTrue(self._anotherWhenRunningCalled)
def whenRunningCallback(self):
self._whenRunningCalled = True
self._reactor.callWhenRunning(self.anotherWhenRunningCallback)
self._reactor.stop()
def anotherWhenRunningCallback(self):
self._anotherWhenRunningCalled = True
@skipIfNoTwisted
class ReactorCallLaterTest(ReactorTestCase):
def test_callLater(self):
self._laterCalled = False
self._now = self._reactor.seconds()
self._timeout = 0.001
dc = self._reactor.callLater(self._timeout, self.callLaterCallback)
self.assertEqual(self._reactor.getDelayedCalls(), [dc])
self._reactor.run()
self.assertTrue(self._laterCalled)
self.assertTrue(self._called - self._now > self._timeout)
self.assertEqual(self._reactor.getDelayedCalls(), [])
def callLaterCallback(self):
self._laterCalled = True
self._called = self._reactor.seconds()
self._reactor.stop()
@skipIfNoTwisted
class ReactorTwoCallLaterTest(ReactorTestCase):
def test_callLater(self):
self._later1Called = False
self._later2Called = False
self._now = self._reactor.seconds()
self._timeout1 = 0.0005
dc1 = self._reactor.callLater(self._timeout1, self.callLaterCallback1)
self._timeout2 = 0.001
dc2 = self._reactor.callLater(self._timeout2, self.callLaterCallback2)
self.assertTrue(self._reactor.getDelayedCalls() == [dc1, dc2] or
self._reactor.getDelayedCalls() == [dc2, dc1])
self._reactor.run()
self.assertTrue(self._later1Called)
self.assertTrue(self._later2Called)
self.assertTrue(self._called1 - self._now > self._timeout1)
self.assertTrue(self._called2 - self._now > self._timeout2)
self.assertEqual(self._reactor.getDelayedCalls(), [])
def callLaterCallback1(self):
self._later1Called = True
self._called1 = self._reactor.seconds()
def callLaterCallback2(self):
self._later2Called = True
self._called2 = self._reactor.seconds()
self._reactor.stop()
@skipIfNoTwisted
class ReactorCallFromThreadTest(ReactorTestCase):
def setUp(self):
super(ReactorCallFromThreadTest, self).setUp()
self._mainThread = thread.get_ident()
def tearDown(self):
self._thread.join()
super(ReactorCallFromThreadTest, self).tearDown()
def _newThreadRun(self):
self.assertNotEqual(self._mainThread, thread.get_ident())
if hasattr(self._thread, 'ident'): # new in python 2.6
self.assertEqual(self._thread.ident, thread.get_ident())
self._reactor.callFromThread(self._fnCalledFromThread)
def _fnCalledFromThread(self):
self.assertEqual(self._mainThread, thread.get_ident())
self._reactor.stop()
def _whenRunningCallback(self):
self._thread = threading.Thread(target=self._newThreadRun)
self._thread.start()
def testCallFromThread(self):
self._reactor.callWhenRunning(self._whenRunningCallback)
self._reactor.run()
@skipIfNoTwisted
class ReactorCallInThread(ReactorTestCase):
def setUp(self):
super(ReactorCallInThread, self).setUp()
self._mainThread = thread.get_ident()
def _fnCalledInThread(self, *args, **kwargs):
self.assertNotEqual(thread.get_ident(), self._mainThread)
self._reactor.callFromThread(lambda: self._reactor.stop())
def _whenRunningCallback(self):
self._reactor.callInThread(self._fnCalledInThread)
def testCallInThread(self):
self._reactor.callWhenRunning(self._whenRunningCallback)
self._reactor.run()
class Reader(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Reader"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def readConnectionLost(self, reason):
self.close()
def connectionLost(self, reason):
self.close()
def doRead(self):
self._callback(self._fd)
if have_twisted:
Reader = implementer(IReadDescriptor)(Reader)
class Writer(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Writer"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def connectionLost(self, reason):
self.close()
def doWrite(self):
self._callback(self._fd)
if have_twisted:
Writer = implementer(IWriteDescriptor)(Writer)
@skipIfNoTwisted
class ReactorReaderWriterTest(ReactorTestCase):
def _set_nonblocking(self, fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
def setUp(self):
super(ReactorReaderWriterTest, self).setUp()
r, w = os.pipe()
self._set_nonblocking(r)
self._set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self._p1 = os.fdopen(r, "rb", 0)
self._p2 = os.fdopen(w, "wb", 0)
def tearDown(self):
super(ReactorReaderWriterTest, self).tearDown()
self._p1.close()
self._p2.close()
def _testReadWrite(self):
"""
In this test the writer writes an 'x' to its fd. The reader
reads it, check the value and ends the test.
"""
self.shouldWrite = True
def checkReadInput(fd):
self.assertEquals(fd.read(1), b'x')
self._reactor.stop()
def writeOnce(fd):
if self.shouldWrite:
self.shouldWrite = False
fd.write(b'x')
self._reader = Reader(self._p1, checkReadInput)
self._writer = Writer(self._p2, writeOnce)
self._reactor.addWriter(self._writer)
# Test that adding the reader twice adds it only once to
# IOLoop.
self._reactor.addReader(self._reader)
self._reactor.addReader(self._reader)
def testReadWrite(self):
self._reactor.callWhenRunning(self._testReadWrite)
self._reactor.run()
def _testNoWriter(self):
"""
In this test we have no writer. Make sure the reader doesn't
read anything.
"""
def checkReadInput(fd):
self.fail("Must not be called.")
def stopTest():
# Close the writer here since the IOLoop doesn't know
# about it.
self._writer.close()
self._reactor.stop()
self._reader = Reader(self._p1, checkReadInput)
# We create a writer, but it should never be invoked.
self._writer = Writer(self._p2, lambda fd: fd.write('x'))
# Test that adding and removing the writer leaves us with no writer.
self._reactor.addWriter(self._writer)
self._reactor.removeWriter(self._writer)
# Test that adding and removing the reader doesn't cause
# unintended effects.
self._reactor.addReader(self._reader)
# Wake up after a moment and stop the test
self._reactor.callLater(0.001, stopTest)
def testNoWriter(self):
self._reactor.callWhenRunning(self._testNoWriter)
self._reactor.run()
# Test various combinations of twisted and tornado http servers,
# http clients, and event loop interfaces.
@skipIfNoTwisted
@unittest.skipIf(not have_twisted_web, 'twisted web not present')
class CompatibilityTests(unittest.TestCase):
def setUp(self):
self.saved_signals = save_signal_handlers()
self.io_loop = IOLoop()
self.io_loop.make_current()
self.reactor = TornadoReactor(self.io_loop)
def tearDown(self):
self.reactor.disconnectAll()
self.io_loop.clear_current()
self.io_loop.close(all_fds=True)
restore_signal_handlers(self.saved_signals)
def start_twisted_server(self):
class HelloResource(Resource):
isLeaf = True
def render_GET(self, request):
return "Hello from twisted!"
site = Site(HelloResource())
port = self.reactor.listenTCP(0, site, interface='127.0.0.1')
self.twisted_port = port.getHost().port
def start_tornado_server(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello from tornado!")
app = Application([('/', HelloHandler)],
log_function=lambda x: None)
server = HTTPServer(app, io_loop=self.io_loop)
sock, self.tornado_port = bind_unused_port()
server.add_sockets([sock])
def run_ioloop(self):
self.stop_loop = self.io_loop.stop
self.io_loop.start()
self.reactor.fireSystemEvent('shutdown')
def run_reactor(self):
self.stop_loop = self.reactor.stop
self.stop = self.reactor.stop
self.reactor.run()
def tornado_fetch(self, url, runner):
responses = []
client = AsyncHTTPClient(self.io_loop)
def callback(response):
responses.append(response)
self.stop_loop()
client.fetch(url, callback=callback)
runner()
self.assertEqual(len(responses), 1)
responses[0].rethrow()
return responses[0]
def twisted_fetch(self, url, runner):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
d = client.request('GET', url)
class Accumulator(Protocol):
def __init__(self, finished):
self.finished = finished
def dataReceived(self, data):
chunks.append(data)
def connectionLost(self, reason):
self.finished.callback(None)
def callback(response):
finished = Deferred()
response.deliverBody(Accumulator(finished))
return finished
d.addCallback(callback)
def shutdown(ignored):
self.stop_loop()
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return ''.join(chunks)
def testTwistedServerTornadoClientIOLoop(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_ioloop)
self.assertEqual(response.body, 'Hello from twisted!')
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_reactor)
self.assertEqual(response.body, 'Hello from twisted!')
def testTornadoServerTwistedClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
if have_twisted:
# Import and run as much of twisted's test suite as possible.
# This is unfortunately rather dependent on implementation details,
# but there doesn't appear to be a clean all-in-one conformance test
# suite for reactors.
#
# This is a list of all test suites using the ReactorBuilder
# available in Twisted 11.0.0 and 11.1.0 (and a blacklist of
# specific test methods to be disabled).
twisted_tests = {
'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
'twisted.internet.test.test_core.SystemEventTestsBuilder': [
'test_iterate', # deliberately not supported
# Fails on TwistedIOLoop and AsyncIOLoop.
'test_runAfterCrash',
],
'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
"test_lostFileDescriptor", # incompatible with epoll and kqueue
],
'twisted.internet.test.test_process.ProcessTestsBuilder': [
# Only work as root. Twisted's "skip" functionality works
# with py27+, but not unittest2 on py26.
'test_changeGID',
'test_changeUID',
],
# Process tests appear to work on OSX 10.7, but not 10.6
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'test_systemCallUninterruptedByChildExit',
# ],
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
'test_badContext', # ssl-related; see also SSLClientTestsMixin
],
'twisted.internet.test.test_tcp.TCPPortTestsBuilder': [
# These use link-local addresses and cause firewall prompts on mac
'test_buildProtocolIPv6AddressScopeID',
'test_portGetHostOnIPv6ScopeID',
'test_serverGetHostOnIPv6ScopeID',
'test_serverGetPeerOnIPv6ScopeID',
],
'twisted.internet.test.test_tcp.TCPConnectionTestsBuilder': [],
'twisted.internet.test.test_tcp.WriteSequenceTests': [],
'twisted.internet.test.test_tcp.AbortConnectionTestCase': [],
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
'twisted.internet.test.test_time.TimeTestsBuilder': [],
# Extra third-party dependencies (pyOpenSSL)
#'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
# Platform-specific. These tests would be skipped automatically
# if we were running twisted's own test runner.
'test_connectToLinuxAbstractNamespace',
'test_listenOnLinuxAbstractNamespace',
# These tests use twisted's sendmsg.c extension and sometimes
# fail with what looks like uninitialized memory errors
# (more common on pypy than cpython, but I've seen it on both)
'test_sendFileDescriptor',
'test_sendFileDescriptorTriggersPauseProducing',
'test_descriptorDeliveredBeforeBytes',
'test_avoidLeakingFileDescriptors',
],
'twisted.internet.test.test_unix.UNIXDatagramTestsBuilder': [
'test_listenOnLinuxAbstractNamespace',
],
'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [],
}
for test_name, blacklist in twisted_tests.items():
try:
test_class = import_object(test_name)
except (ImportError, AttributeError):
continue
for test_func in blacklist:
if hasattr(test_class, test_func):
# The test_func may be defined in a mixin, so clobber
# it instead of delattr()
setattr(test_class, test_func, lambda self: None)
def make_test_subclass(test_class):
class TornadoTest(test_class):
_reactors = ["tornado.platform.twisted._TestReactor"]
def setUp(self):
# Twisted's tests expect to be run from a temporary
# directory; they create files in their working directory
# and don't always clean up after themselves.
self.__curdir = os.getcwd()
self.__tempdir = tempfile.mkdtemp()
os.chdir(self.__tempdir)
super(TornadoTest, self).setUp()
def tearDown(self):
super(TornadoTest, self).tearDown()
os.chdir(self.__curdir)
shutil.rmtree(self.__tempdir)
def buildReactor(self):
self.__saved_signals = save_signal_handlers()
return test_class.buildReactor(self)
def unbuildReactor(self, reactor):
test_class.unbuildReactor(self, reactor)
# Clean up file descriptors (especially epoll/kqueue
# objects) eagerly instead of leaving them for the
# GC. Unfortunately we can't do this in reactor.stop
# since twisted expects to be able to unregister
# connections in a post-shutdown hook.
reactor._io_loop.close(all_fds=True)
restore_signal_handlers(self.__saved_signals)
TornadoTest.__name__ = test_class.__name__
return TornadoTest
test_subclass = make_test_subclass(test_class)
globals().update(test_subclass.makeTestCaseClasses())
# Since we're not using twisted's test runner, it's tricky to get
# logging set up well. Most of the time it's easiest to just
# leave it turned off, but while working on these tests you may want
# to uncomment one of the other lines instead.
log.defaultObserver.stop()
# import sys; log.startLogging(sys.stderr, setStdout=0)
# log.startLoggingWithObserver(log.PythonLoggingObserver().emit, setStdout=0)
# import logging; logging.getLogger('twisted').setLevel(logging.WARNING)
if have_twisted:
class LayeredTwistedIOLoop(TwistedIOLoop):
"""Layers a TwistedIOLoop on top of a TornadoReactor on a SelectIOLoop.
This is of course silly, but is useful for testing purposes to make
sure we're implementing both sides of the various interfaces
correctly. In some tests another TornadoReactor is layered on top
of the whole stack.
"""
def initialize(self):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop()
reactor = TornadoReactor(io_loop=self.real_io_loop)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
self.add_callback(self.make_current)
def close(self, all_fds=False):
super(LayeredTwistedIOLoop, self).close(all_fds=all_fds)
# HACK: This is the same thing that test_class.unbuildReactor does.
for reader in self.reactor._internalReaders:
self.reactor.removeReader(reader)
reader.connectionLost(None)
self.real_io_loop.close(all_fds=all_fds)
def stop(self):
# One of twisted's tests fails if I don't delay crash()
# until the reactor has started, but if I move this to
# TwistedIOLoop then the tests fail when I'm *not* running
# tornado-on-twisted-on-tornado. I'm clearly missing something
# about the startup/crash semantics, but since stop and crash
# are really only used in tests it doesn't really matter.
self.reactor.callWhenRunning(self.reactor.crash)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,30 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import socket
import sys
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info < (2, 7):
# In py26, we must always use unittest2.
import unittest2 as unittest
else:
# Otherwise, use whichever version of unittest was imported in
# tornado.testing.
from tornado.testing import unittest
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
"non-unix platform")
# travis-ci.org runs our tests in an overworked virtual machine, which makes
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
'timing tests unreliable on travis')
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')

View file

@ -0,0 +1,172 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement
import sys
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer
from tornado.test.util import unittest
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
class RaiseExcInfoTest(unittest.TestCase):
def test_two_arg_exception(self):
# This test would fail on python 3 if raise_exc_info were simply
# a three-argument raise statement, because TwoArgException
# doesn't have a "copy constructor"
class TwoArgException(Exception):
def __init__(self, a, b):
super(TwoArgException, self).__init__()
self.a, self.b = a, b
try:
raise TwoArgException(1, 2)
except TwoArgException:
exc_info = sys.exc_info()
try:
raise_exc_info(exc_info)
self.fail("didn't get expected exception")
except TwoArgException as e:
self.assertIs(e, exc_info[1])
class TestConfigurable(Configurable):
@classmethod
def configurable_base(cls):
return TestConfigurable
@classmethod
def configurable_default(cls):
return TestConfig1
class TestConfig1(TestConfigurable):
def initialize(self, a=None):
self.a = a
class TestConfig2(TestConfigurable):
def initialize(self, b=None):
self.b = b
class ConfigurableTest(unittest.TestCase):
def setUp(self):
self.saved = TestConfigurable._save_configuration()
def tearDown(self):
TestConfigurable._restore_configuration(self.saved)
def checkSubclasses(self):
# no matter how the class is configured, it should always be
# possible to instantiate the subclasses directly
self.assertIsInstance(TestConfig1(), TestConfig1)
self.assertIsInstance(TestConfig2(), TestConfig2)
obj = TestConfig1(a=1)
self.assertEqual(obj.a, 1)
obj = TestConfig2(b=2)
self.assertEqual(obj.b, 2)
def test_default(self):
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig1)
self.assertIs(obj.a, None)
obj = TestConfigurable(a=1)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 1)
self.checkSubclasses()
def test_config_class(self):
TestConfigurable.configure(TestConfig2)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig2)
self.assertIs(obj.b, None)
obj = TestConfigurable(b=2)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 2)
self.checkSubclasses()
def test_config_args(self):
TestConfigurable.configure(None, a=3)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
obj = TestConfigurable(a=4)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig1()
self.assertIs(obj.a, None)
def test_config_class_args(self):
TestConfigurable.configure(TestConfig2, b=5)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
obj = TestConfigurable(b=6)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig2()
self.assertIs(obj.b, None)
class UnicodeLiteralTest(unittest.TestCase):
def test_unicode_escapes(self):
self.assertEqual(utf8(u('\u00e9')), b'\xc3\xa9')
class ExecInTest(unittest.TestCase):
# This test is python 2 only because there are no new future imports
# defined in python 3 yet.
@unittest.skipIf(sys.version_info >= print_function.getMandatoryRelease(),
'no testable future imports')
def test_no_inherit_future(self):
# This file has from __future__ import print_function...
f = StringIO()
print('hello', file=f)
# ...but the template doesn't
exec_in('print >> f, "world"', dict(f=f))
self.assertEqual(f.getvalue(), 'hello\nworld\n')
class ArgReplacerTest(unittest.TestCase):
def setUp(self):
def function(x, y, callback=None, z=None):
pass
self.replacer = ArgReplacer(function, 'callback')
def test_omitted(self):
args = (1, 2)
kwargs = dict()
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
self.assertEqual(self.replacer.replace('new', args, kwargs),
(None, (1, 2), dict(callback='new')))
def test_position(self):
args = (1, 2, 'old', 3)
kwargs = dict()
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', [1, 2, 'new', 3], dict()))
def test_keyword(self):
args = (1,)
kwargs = dict(y=2, callback='old', z=3)
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,313 @@
from __future__ import absolute_import, division, print_function, with_statement
import traceback
from tornado.concurrent import Future
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
try:
from tornado import speedups
except ImportError:
speedups = None
class TestWebSocketHandler(WebSocketHandler):
"""Base class for testing handlers that exposes the on_close event.
This allows for deterministic cleanup of the associated socket.
"""
def initialize(self, close_future):
self.close_future = close_future
def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
self.set_status(503)
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get('X-Test', ''))
class NonWebSocketHandler(RequestHandler):
def get(self):
self.write('ok')
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
('/header', HeaderHandler, dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
])
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch('/echo')
self.assertEqual(response.code, 400)
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
def test_websocket_callbacks(self):
websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, 'hello')
self.close_future.add_done_callback(lambda f: self.stop())
ws.close()
self.wait()
@gen_test
def test_binary_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
ws.close()
yield self.close_future
@gen_test
def test_unicode_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
ws.close()
yield self.close_future
@gen_test
def test_error_in_on_message(self):
ws = yield websocket_connect(
'ws://localhost:%d/error_in_on_message' % self.get_http_port())
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
ws.close()
yield self.close_future
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(
'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
yield websocket_connect(
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
@gen_test
def test_websocket_network_fail(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
ws.stream.close()
yield self.close_future
@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_server_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/close_reason' % self.get_http_port())
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
@gen_test
def test_client_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'localhost:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Host is localhost, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid_subdomains(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Subdomains should be disallowed by default. If we could pass a
# resolver to websocket_connect we could test sibling domains as well.
headers = {'Origin': 'http://subtenant.localhost'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):
self.assertEqual(self.mask(b'abcd', b''), b'')
self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
# Include test cases with \x00 bytes (to ensure that the C
# extension isn't depending on null-terminated strings) and
# bytes with the high bit set (to smoke out signedness issues).
self.assertEqual(self.mask(b'\x00\x01\x02\x03',
b'\xff\xfb\xfd\xfc\xfe\xfa'),
b'\xff\xfa\xff\xff\xfe\xfb')
self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
b'\x00\x01\x02\x03\x04\x05'),
b'\xff\xfa\xff\xff\xfb\xfe')
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return _websocket_mask_python(mask, data)
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return speedups.websocket_mask(mask, data)

View file

@ -0,0 +1,100 @@
from __future__ import absolute_import, division, print_function, with_statement
from wsgiref.validate import validator
from tornado.escape import json_decode
from tornado.test.httpserver_test import TypeCheckHandler
from tornado.testing import AsyncHTTPTestCase
from tornado.util import u
from tornado.web import RequestHandler, Application
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
class WSGIContainerTest(AsyncHTTPTestCase):
def wsgi_app(self, environ, start_response):
status = "200 OK"
response_headers = [("Content-Type", "text/plain")]
start_response(status, response_headers)
return [b"Hello world!"]
def get_app(self):
return WSGIContainer(validator(self.wsgi_app))
def test_simple(self):
response = self.fetch("/")
self.assertEqual(response.body, b"Hello world!")
class WSGIApplicationTest(AsyncHTTPTestCase):
def get_app(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello world!")
class PathQuotingHandler(RequestHandler):
def get(self, path):
self.write(path)
# It would be better to run the wsgiref server implementation in
# another thread instead of using our own WSGIContainer, but this
# fits better in our async testing framework and the wsgiref
# validator should keep us honest
return WSGIContainer(validator(WSGIApplication([
("/", HelloHandler),
("/path/(.*)", PathQuotingHandler),
("/typecheck", TypeCheckHandler),
])))
def test_simple(self):
response = self.fetch("/")
self.assertEqual(response.body, b"Hello world!")
def test_path_quoting(self):
response = self.fetch("/path/foo%20bar%C3%A9")
self.assertEqual(response.body, u("foo bar\u00e9").encode("utf-8"))
def test_types(self):
headers = {"Cookie": "foo=bar"}
response = self.fetch("/typecheck?foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
# This is kind of hacky, but run some of the HTTPServer tests through
# WSGIContainer and WSGIApplication to make sure everything survives
# repeated disassembly and reassembly.
from tornado.test import httpserver_test
from tornado.test import web_test
class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
def get_app(self):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
def wrap_web_tests_application():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIApplicationWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
return result
globals().update(wrap_web_tests_application())
def wrap_web_tests_adapter():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIAdapterWrappedTest(cls):
def get_app(self):
self.app = Application(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(WSGIAdapter(self.app)))
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
return result
globals().update(wrap_web_tests_adapter())

View file

@ -0,0 +1,691 @@
#!/usr/bin/env python
"""Support classes for automated testing.
* `AsyncTestCase` and `AsyncHTTPTestCase`: Subclasses of unittest.TestCase
with additional support for testing asynchronous (`.IOLoop` based) code.
* `ExpectLog` and `LogTrapTestCase`: Make test logs less spammy.
* `main()`: A simple test runner (wrapper around unittest.main()) with support
for the tornado.autoreload module to rerun the tests when code changes.
"""
from __future__ import absolute_import, division, print_function, with_statement
try:
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
except ImportError:
# These modules are not importable on app engine. Parts of this module
# won't work, but e.g. LogTrapTestCase and main() will.
AsyncHTTPClient = None
gen = None
HTTPServer = None
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
import functools
import logging
import os
import re
import signal
import socket
import sys
import types
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
# Tornado's own test suite requires the updated unittest module
# (either py27+ or unittest2) so tornado.test.util enforces
# this requirement, but for other users of tornado.testing we want
# to allow the older version if unitest2 is not available.
if sys.version_info >= (3,):
# On python 3, mixing unittest2 and unittest (including doctest)
# doesn't seem to work, so always use unittest.
import unittest
else:
# On python 2, prefer unittest2 when available.
try:
import unittest2 as unittest
except ImportError:
import unittest
_next_port = 10000
def get_unused_port():
"""Returns a (hopefully) unused port number.
This function does not guarantee that the port it returns is available,
only that a series of get_unused_port calls in a single process return
distinct ports.
.. deprecated::
Use bind_unused_port instead, which is guaranteed to find an unused port.
"""
global _next_port
port = _next_port
_next_port = _next_port + 1
return port
def bind_unused_port():
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
"""
[sock] = netutil.bind_sockets(None, 'localhost', family=socket.AF_INET)
port = sock.getsockname()[1]
return sock, port
def get_async_test_timeout():
"""Get the global timeout setting for async tests.
Returns a float, the timeout in seconds.
.. versionadded:: 3.1
"""
try:
return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
except (ValueError, TypeError):
return 5
class _TestMethodWrapper(object):
"""Wraps a test method to raise an error if it returns a value.
This is mainly used to detect undecorated generators (if a test
method yields it must use a decorator to consume the generator),
but will also detect other kinds of return values (these are not
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test.
"""
def __init__(self, orig_method):
self.orig_method = orig_method
def __call__(self):
result = self.orig_method()
if isinstance(result, types.GeneratorType):
raise TypeError("Generator test methods should be decorated with "
"tornado.testing.gen_test")
elif result is not None:
raise ValueError("Return value from test method ignored: %r" %
result)
def __getattr__(self, name):
"""Proxy all unknown attributes to the original method.
This is important for some of the decorators in the `unittest`
module, such as `unittest.skipIf`.
"""
return getattr(self.orig_method, name)
class AsyncTestCase(unittest.TestCase):
"""`~unittest.TestCase` subclass for testing `.IOLoop`-based
asynchronous code.
The unittest framework is synchronous, so the test must be
complete by the time the test method returns. This means that
asynchronous code cannot be used in quite the same way as usual.
To write test functions that use the same ``yield``-based patterns
used with the `tornado.gen` module, decorate your test methods
with `tornado.testing.gen_test` instead of
`tornado.gen.coroutine`. This class also provides the `stop()`
and `wait()` methods for a more manual style of testing. The test
method itself must call ``self.wait()``, and asynchronous
callbacks should call ``self.stop()`` to signal completion.
By default, a new `.IOLoop` is constructed for each test and is available
as ``self.io_loop``. This `.IOLoop` should be used in the construction of
HTTP clients/servers, etc. If the code being tested requires a
global `.IOLoop`, subclasses should override `get_new_ioloop` to return it.
The `.IOLoop`'s ``start`` and ``stop`` methods should not be
called directly. Instead, use `self.stop <stop>` and `self.wait
<wait>`. Arguments passed to ``self.stop`` are returned from
``self.wait``. It is possible to have multiple ``wait``/``stop``
cycles in the same test.
Example::
# This test uses coroutine style.
class MyTestCase(AsyncTestCase):
@tornado.testing.gen_test
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
response = yield client.fetch("http://www.tornadoweb.org")
# Test contents of response
self.assertIn("FriendFeed", response.body)
# This test uses argument passing between self.stop and self.wait.
class MyTestCase2(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
client.fetch("http://www.tornadoweb.org/", self.stop)
response = self.wait()
# Test contents of response
self.assertIn("FriendFeed", response.body)
# This test uses an explicit callback-based style.
class MyTestCase3(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
client.fetch("http://www.tornadoweb.org/", self.handle_fetch)
self.wait()
def handle_fetch(self, response):
# Test contents of response (failures and exceptions here
# will cause self.wait() to throw an exception and end the
# test).
# Exceptions thrown here are magically propagated to
# self.wait() in test_http_fetch() via stack_context.
self.assertIn("FriendFeed", response.body)
self.stop()
"""
def __init__(self, methodName='runTest', **kwargs):
super(AsyncTestCase, self).__init__(methodName, **kwargs)
self.__stopped = False
self.__running = False
self.__failure = None
self.__stop_args = None
self.__timeout = None
# It's easy to forget the @gen_test decorator, but if you do
# the test will silently be ignored because nothing will consume
# the generator. Replace the test method with a wrapper that will
# make sure it's not an undecorated generator.
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
def setUp(self):
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
self.io_loop.make_current()
def tearDown(self):
self.io_loop.clear_current()
if (not IOLoop.initialized() or
self.io_loop is not IOLoop.instance()):
# Try to clean up any file descriptors left open in the ioloop.
# This avoids leaks, especially when tests are run repeatedly
# in the same process with autoreload (because curl does not
# set FD_CLOEXEC on its file descriptors)
self.io_loop.close(all_fds=True)
super(AsyncTestCase, self).tearDown()
# In case an exception escaped or the StackContext caught an exception
# when there wasn't a wait() to re-raise it, do so here.
# This is our last chance to raise an exception in a way that the
# unittest machinery understands.
self.__rethrow()
def get_new_ioloop(self):
"""Creates a new `.IOLoop` for this test. May be overridden in
subclasses for tests that require a specific `.IOLoop` (usually
the singleton `.IOLoop.instance()`).
"""
return IOLoop()
def _handle_exception(self, typ, value, tb):
self.__failure = (typ, value, tb)
self.stop()
return True
def __rethrow(self):
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
def run(self, result=None):
with ExceptionStackContext(self._handle_exception):
super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
# unittest run to fail messily, but that's better than silently
# ignoring an error.
self.__rethrow()
def stop(self, _arg=None, **kwargs):
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
Keyword arguments or a single positional argument passed to `stop()` are
saved and will be returned by `wait()`.
"""
assert _arg is None or not kwargs
self.__stop_args = kwargs or _arg
if self.__running:
self.io_loop.stop()
self.__running = False
self.__stopped = True
def wait(self, condition=None, timeout=None):
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
default timeout is 5 seconds; it may be overridden with a
``timeout`` keyword argument or globally with the
``ASYNC_TEST_TIMEOUT`` environment variable.
If ``condition`` is not None, the `.IOLoop` will be restarted
after `stop()` until ``condition()`` returns true.
.. versionchanged:: 3.1
Added the ``ASYNC_TEST_TIMEOUT`` environment variable.
"""
if timeout is None:
timeout = get_async_test_timeout()
if not self.__stopped:
if timeout:
def timeout_func():
try:
raise self.failureException(
'Async operation timed out after %s seconds' %
timeout)
except Exception:
self.__failure = sys.exc_info()
self.stop()
self.__timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, timeout_func)
while True:
self.__running = True
self.io_loop.start()
if (self.__failure is not None or
condition is None or condition()):
break
if self.__timeout is not None:
self.io_loop.remove_timeout(self.__timeout)
self.__timeout = None
assert self.__stopped
self.__stopped = False
self.__rethrow()
result = self.__stop_args
self.__stop_args = None
return result
class AsyncHTTPTestCase(AsyncTestCase):
"""A test case that starts up an HTTP server.
Subclasses must override `get_app()`, which returns the
`tornado.web.Application` (or other `.HTTPServer` callback) to be tested.
Tests will typically use the provided ``self.http_client`` to fetch
URLs from this server.
Example::
class MyHTTPTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', MyHandler)...])
def test_homepage(self):
# The following two lines are equivalent to
# response = self.fetch('/')
# but are shown in full here to demonstrate explicit use
# of self.stop and self.wait.
self.http_client.fetch(self.get_url('/'), self.stop)
response = self.wait()
# test contents of response
"""
def setUp(self):
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
self.__port = port
self.http_client = self.get_http_client()
self._app = self.get_app()
self.http_server = self.get_http_server()
self.http_server.add_sockets([sock])
def get_http_client(self):
return AsyncHTTPClient(io_loop=self.io_loop)
def get_http_server(self):
return HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
def get_app(self):
"""Should be overridden by subclasses to return a
`tornado.web.Application` or other `.HTTPServer` callback.
"""
raise NotImplementedError()
def fetch(self, path, **kwargs):
"""Convenience method to synchronously fetch a url.
The given path will be appended to the local server's host and
port. Any additional kwargs will be passed directly to
`.AsyncHTTPClient.fetch` (and so could be used to pass
``method="POST"``, ``body="..."``, etc).
"""
self.http_client.fetch(self.get_url(path), self.stop, **kwargs)
return self.wait()
def get_httpserver_options(self):
"""May be overridden by subclasses to return additional
keyword arguments for the server.
"""
return {}
def get_http_port(self):
"""Returns the port used by the server.
A new port is chosen for each test.
"""
return self.__port
def get_protocol(self):
return 'http'
def get_url(self, path):
"""Returns an absolute url for the given path on the test server."""
return '%s://localhost:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
def tearDown(self):
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections)
if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close()
super(AsyncHTTPTestCase, self).tearDown()
class AsyncHTTPSTestCase(AsyncHTTPTestCase):
"""A test case that starts an HTTPS server.
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
defaults=dict(validate_cert=False))
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self):
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
"""
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'))
def get_protocol(self):
return 'https'
def gen_test(func=None, timeout=None):
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
already running. ``@gen_test`` should be applied to test methods
on subclasses of `AsyncTestCase`.
Example::
class MyTest(AsyncHTTPTestCase):
@gen_test
def test_something(self):
response = yield gen.Task(self.fetch('/'))
By default, ``@gen_test`` times out after 5 seconds. The timeout may be
overridden globally with the ``ASYNC_TEST_TIMEOUT`` environment variable,
or for each test with the ``timeout`` keyword argument::
class MyTest(AsyncHTTPTestCase):
@gen_test(timeout=10)
def test_something_slow(self):
response = yield gen.Task(self.fetch('/'))
.. versionadded:: 3.1
The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment
variable.
.. versionchanged:: 4.0
The wrapper now passes along ``*args, **kwargs`` so it can be used
on functions with arguments.
"""
if timeout is None:
timeout = get_async_test_timeout()
def wrap(f):
# Stack up several decorators to allow us to access the generator
# object itself. In the innermost wrapper, we capture the generator
# and save it in an attribute of self. Next, we run the wrapped
# function through @gen.coroutine. Finally, the coroutine is
# wrapped again to make it synchronous with run_sync.
#
# This is a good case study arguing for either some sort of
# extensibility in the gen decorators or cancellation support.
@functools.wraps(f)
def pre_coroutine(self, *args, **kwargs):
result = f(self, *args, **kwargs)
if isinstance(result, types.GeneratorType):
self._test_generator = result
else:
self._test_generator = None
return result
coro = gen.coroutine(pre_coroutine)
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
try:
return self.io_loop.run_sync(
functools.partial(coro, self, *args, **kwargs),
timeout=timeout)
except TimeoutError as e:
# run_sync raises an error with an unhelpful traceback.
# If we throw it back into the generator the stack trace
# will be replaced by the point where the test is stopped.
self._test_generator.throw(e)
# In case the test contains an overly broad except clause,
# we may get back here. In this case re-raise the original
# exception, which is better than nothing.
raise
return post_coroutine
if func is not None:
# Used like:
# @gen_test
# def f(self):
# pass
return wrap(func)
else:
# Used like @gen_test(timeout=10)
return wrap
# Without this attribute, nosetests will try to run gen_test as a test
# anywhere it is imported.
gen_test.__test__ = False
class LogTrapTestCase(unittest.TestCase):
"""A test case that captures and discards all logging output
if the test passes.
Some libraries can produce a lot of logging output even when
the test succeeds, so this class can be useful to minimize the noise.
Simply use it as a base class for your test case. It is safe to combine
with AsyncTestCase via multiple inheritance
(``class MyTestCase(AsyncHTTPTestCase, LogTrapTestCase):``)
This class assumes that only one log handler is configured and
that it is a `~logging.StreamHandler`. This is true for both
`logging.basicConfig` and the "pretty logging" configured by
`tornado.options`. It is not compatible with other log buffering
mechanisms, such as those provided by some test runners.
"""
def run(self, result=None):
logger = logging.getLogger()
if not logger.handlers:
logging.basicConfig()
handler = logger.handlers[0]
if (len(logger.handlers) > 1 or
not isinstance(handler, logging.StreamHandler)):
# Logging has been configured in a way we don't recognize,
# so just leave it alone.
super(LogTrapTestCase, self).run(result)
return
old_stream = handler.stream
try:
handler.stream = StringIO()
gen_log.info("RUNNING TEST: " + str(self))
old_error_count = len(result.failures) + len(result.errors)
super(LogTrapTestCase, self).run(result)
new_error_count = len(result.failures) + len(result.errors)
if new_error_count != old_error_count:
old_stream.write(handler.stream.getvalue())
finally:
handler.stream = old_stream
class ExpectLog(logging.Filter):
"""Context manager to capture and suppress expected log output.
Useful to make tests of error conditions less noisy, while still
leaving unexpected log entries visible. *Not thread safe.*
Usage::
with ExpectLog('tornado.application', "Uncaught exception"):
error_response = self.fetch("/some_page")
"""
def __init__(self, logger, regex, required=True):
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
an empty string to watch the root logger.
:param regex: Regular expression to match. Any log entries on
the specified logger that match this regex will be suppressed.
:param required: If true, an exeption will be raised if the end of
the ``with`` statement is reached without matching any log entries.
"""
if isinstance(logger, basestring_type):
logger = logging.getLogger(logger)
self.logger = logger
self.regex = re.compile(regex)
self.required = required
self.matched = False
def filter(self, record):
message = record.getMessage()
if self.regex.match(message):
self.matched = True
return False
return True
def __enter__(self):
self.logger.addFilter(self)
def __exit__(self, typ, value, tb):
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
def main(**kwargs):
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
the standard library, but adds support for tornado-style option
parsing and log formatting.
The easiest way to run a test is via the command line::
python -m tornado.testing tornado.test.stack_context_test
See the standard library unittest module for ways in which tests can
be specified.
Projects with many tests may wish to define a test script like
``tornado/test/runtests.py``. This script should define a method
``all()`` which returns a test suite and then call
`tornado.testing.main()`. Note that even when a test script is
used, the ``all()`` test suite may be overridden by naming a
single test on the command line::
# Runs all tests
python -m tornado.test.runtests
# Runs one test
python -m tornado.test.runtests tornado.test.stack_context_test
Additional keyword arguments passed through to ``unittest.main()``.
For example, use ``tornado.testing.main(verbosity=2)``
to show many test details as they are run.
See http://docs.python.org/library/unittest.html#unittest.main
for full argument list.
"""
from tornado.options import define, options, parse_command_line
define('exception_on_interrupt', type=bool, default=True,
help=("If true (default), ctrl-c raises a KeyboardInterrupt "
"exception. This prints a stack trace but cannot interrupt "
"certain operations. If false, the process is more reliably "
"killed, but does not print a stack trace."))
# support the same options as unittest's command-line interface
define('verbose', type=bool)
define('quiet', type=bool)
define('failfast', type=bool)
define('catch', type=bool)
define('buffer', type=bool)
argv = [sys.argv[0]] + parse_command_line(sys.argv)
if not options.exception_on_interrupt:
signal.signal(signal.SIGINT, signal.SIG_DFL)
if options.verbose is not None:
kwargs['verbosity'] = 2
if options.quiet is not None:
kwargs['verbosity'] = 0
if options.failfast is not None:
kwargs['failfast'] = True
if options.catch is not None:
kwargs['catchbreak'] = True
if options.buffer is not None:
kwargs['buffer'] = True
if __name__ == '__main__' and len(argv) == 1:
print("No tests specified", file=sys.stderr)
sys.exit(1)
try:
# In order to be able to run tests by their fully-qualified name
# on the command line without importing all tests here,
# module must be set to None. Python 3.2's unittest.main ignores
# defaultTest if no module is given (it tries to do its own
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
unittest.main(module=None, argv=argv, **kwargs)
else:
unittest.main(defaultTest="all", argv=argv, **kwargs)
except SystemExit as e:
if e.code == 0:
gen_log.info('PASS')
else:
gen_log.error('FAIL')
raise
if __name__ == '__main__':
main()

View file

@ -0,0 +1,356 @@
"""Miscellaneous utility functions and classes.
This module is used internally by Tornado. It is not necessarily expected
that the functions and classes defined here will be useful to other
applications, but they are documented here in case they are.
The one public-facing part of this module is the `Configurable` class
and its `~Configurable.configure` method, which becomes a part of the
interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
and `.Resolver`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import array
import inspect
import os
import sys
import zlib
try:
xrange # py2
except NameError:
xrange = range # py3
class ObjectDict(dict):
"""Makes a dictionary behave like an object, with attribute-style access.
"""
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name, value):
self[name] = value
class GzipDecompressor(object):
"""Streaming gzip decompressor.
The interface is like that of `zlib.decompressobj` (without some of the
optional arguments, but it understands gzip headers and checksums.
"""
def __init__(self):
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
def decompress(self, value, max_length=None):
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
be called when there is no more input data to ensure that
all data was processed.
If ``max_length`` is given, some input data may be left over
in ``unconsumed_tail``; you must retrieve this value and pass
it back to a future call to `decompress` if it is not empty.
"""
return self.decompressobj.decompress(value, max_length)
@property
def unconsumed_tail(self):
"""Returns the unconsumed portion left over
"""
return self.decompressobj.unconsumed_tail
def flush(self):
"""Return any remaining buffered data not yet returned by decompress.
Also checks for errors such as truncated input.
No other methods may be called on this object after `flush`.
"""
return self.decompressobj.flush()
def import_object(name):
"""Imports an object by name.
import_object('x') is equivalent to 'import x'.
import_object('x.y.z') is equivalent to 'from x.y import z'.
>>> import tornado.escape
>>> import_object('tornado.escape') is tornado.escape
True
>>> import_object('tornado.escape.utf8') is tornado.escape.utf8
True
>>> import_object('tornado') is tornado
True
>>> import_object('tornado.missing_module')
Traceback (most recent call last):
...
ImportError: No module named missing_module
"""
if name.count('.') == 0:
return __import__(name, None, None)
parts = name.split('.')
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
try:
return getattr(obj, parts[-1])
except AttributeError:
raise ImportError("No module named %s" % parts[-1])
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if type('') is not type(b''):
def u(s):
return s
bytes_type = bytes
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
bytes_type = str
unicode_type = unicode
basestring_type = basestring
if sys.version_info > (3,):
exec("""
def raise_exc_info(exc_info):
raise exc_info[1].with_traceback(exc_info[2])
def exec_in(code, glob, loc=None):
if isinstance(code, str):
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec(code, glob, loc)
""")
else:
exec("""
def raise_exc_info(exc_info):
raise exc_info[0], exc_info[1], exc_info[2]
def exec_in(code, glob, loc=None):
if isinstance(code, basestring):
# exec(string) inherits the caller's future imports; compile
# the string first to prevent that.
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec code in glob, loc
""")
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
class Configurable(object):
"""Base class for configurable interfaces.
A configurable interface is an (abstract) class whose constructor
acts as a factory function for one of its implementation subclasses.
The implementation subclass as well as optional keyword arguments to
its initializer can be set globally at runtime with `configure`.
By using the constructor as the factory method, the interface
looks like a normal class, `isinstance` works as usual, etc. This
pattern is most useful when the choice of implementation is likely
to be a global decision (e.g. when `~select.epoll` is available,
always use it instead of `~select.select`), or when a
previously-monolithic class has been split into specialized
subclasses.
Configurable subclasses must define the class methods
`configurable_base` and `configurable_default`, and use the instance
method `initialize` instead of ``__init__``.
"""
__impl_class = None
__impl_kwargs = None
def __new__(cls, **kwargs):
base = cls.configurable_base()
args = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
args.update(base.__impl_kwargs)
else:
impl = cls
args.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatiblity with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(**args)
return instance
@classmethod
def configurable_base(cls):
"""Returns the base class of a configurable hierarchy.
This will normally return the class in which it is defined.
(which is *not* necessarily the same as the cls classmethod parameter).
"""
raise NotImplementedError()
@classmethod
def configurable_default(cls):
"""Returns the implementation class to be used if none is configured."""
raise NotImplementedError()
def initialize(self):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
"""
@classmethod
def configure(cls, impl, **kwargs):
"""Sets the class to use when the base class is instantiated.
Keyword arguments will be saved and added to the arguments passed
to the constructor. This can be used to set global defaults for
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes_type)):
impl = import_object(impl)
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)
base.__impl_class = impl
base.__impl_kwargs = kwargs
@classmethod
def configured_class(cls):
"""Returns the currently configured class."""
base = cls.configurable_base()
if cls.__impl_class is None:
base.__impl_class = cls.configurable_default()
return base.__impl_class
@classmethod
def _save_configuration(cls):
base = cls.configurable_base()
return (base.__impl_class, base.__impl_kwargs)
@classmethod
def _restore_configuration(cls, saved):
base = cls.configurable_base()
base.__impl_class = saved[0]
base.__impl_kwargs = saved[1]
class ArgReplacer(object):
"""Replaces one value in an ``args, kwargs`` pair.
Inspects the function signature to find an argument by name
whether it is passed by position or keyword. For use in decorators
and similar wrappers.
"""
def __init__(self, func, name):
self.name = name
try:
self.arg_pos = inspect.getargspec(func).args.index(self.name)
except ValueError:
# Not a positional parameter
self.arg_pos = None
def get_old_value(self, args, kwargs, default=None):
"""Returns the old value of the named argument without replacing it.
Returns ``default`` if the argument is not present.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
return args[self.arg_pos]
else:
return kwargs.get(self.name, default)
def replace(self, new_value, args, kwargs):
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
Returns ``(old_value, args, kwargs)``. The returned ``args`` and
``kwargs`` objects may not be the same as the input objects, or
the input objects may be mutated.
If the named argument was not found, ``new_value`` will be added
to ``kwargs`` and None will be returned as ``old_value``.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
# The arg to replace is passed positionally
old_value = args[self.arg_pos]
args = list(args) # *args is normally a tuple
args[self.arg_pos] = new_value
else:
# The arg to replace is either omitted or passed by keyword.
old_value = kwargs.get(self.name)
kwargs[self.name] = new_value
return old_value, args, kwargs
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
def _websocket_mask_python(mask, data):
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytes` object of any length.
Returns a `bytes` object of the same length as `data` with the mask applied
as specified in section 5.3 of RFC 6455.
This pure-python implementation may be replaced by an optimized version when available.
"""
mask = array.array("B", mask)
unmasked = array.array("B", data)
for i in xrange(len(data)):
unmasked[i] = unmasked[i] ^ mask[i % 4]
if hasattr(unmasked, 'tobytes'):
# tostring was deprecated in py32. It hasn't been removed,
# but since we turn on deprecation warnings in our tests
# we need to use the right one.
return unmasked.tobytes()
else:
return unmasked.tostring()
if (os.environ.get('TORNADO_NO_EXTENSION') or
os.environ.get('TORNADO_EXTENSION') == '0'):
# These environment variables exist to make it easier to do performance
# comparisons; they are not guaranteed to remain supported in the future.
_websocket_mask = _websocket_mask_python
else:
try:
from tornado.speedups import websocket_mask as _websocket_mask
except ImportError:
if os.environ.get('TORNADO_EXTENSION') == '1':
raise
_websocket_mask = _websocket_mask_python
def doctests():
import doctest
return doctest.DocTestSuite()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,769 @@
"""Implementation of the WebSocket protocol.
`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional
communication between the browser and server.
WebSockets are supported in the current versions of all major browsers,
although older versions that do not support WebSockets are still in use
(refer to http://caniuse.com/websockets for details).
This module implements the final version of the WebSocket protocol as
defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. Certain
browser versions (notably Safari 5.x) implemented an earlier draft of
the protocol (known as "draft 76") and are not compatible with this module.
.. versionchanged:: 4.0
Removed support for the draft 76 protocol version.
"""
from __future__ import absolute_import, division, print_function, with_statement
# Author: Jacob Kristhammar, 2010
import base64
import collections
import hashlib
import os
import struct
import tornado.escape
import tornado.web
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
from tornado import httpclient, httputil
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado import simple_httpclient
from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, _websocket_mask
try:
from urllib.parse import urlparse # py2
except ImportError:
from urlparse import urlparse # py3
try:
xrange # py2
except NameError:
xrange = range # py3
class WebSocketError(Exception):
pass
class WebSocketClosedError(WebSocketError):
"""Raised by operations on a closed connection.
.. versionadded:: 3.2
"""
pass
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Override `on_message` to handle incoming messages, and use
`write_message` to send messages to the client. You can also
override `open` and `on_close` to handle opened and closed
connections.
See http://dev.w3.org/html5/websockets/ for details on the
JavaScript interface. The protocol is specified at
http://tools.ietf.org/html/rfc6455.
Here is an example WebSocket handler that echos back all received messages
back to the client::
class EchoWebSocket(websocket.WebSocketHandler):
def open(self):
print "WebSocket opened"
def on_message(self, message):
self.write_message(u"You said: " + message)
def on_close(self):
print "WebSocket closed"
WebSockets are not standard HTTP connections. The "handshake" is
HTTP, but after the handshake, the protocol is
message-based. Consequently, most of the Tornado HTTP facilities
are not available in handlers of this type. The only communication
methods available to you are `write_message()`, `ping()`, and
`close()`. Likewise, your request handler class should implement
`open()` method rather than ``get()`` or ``post()``.
If you map the handler above to ``/websocket`` in your application, you can
invoke it in JavaScript with::
var ws = new WebSocket("ws://localhost:8888/websocket");
ws.onopen = function() {
ws.send("Hello, world");
};
ws.onmessage = function (evt) {
alert(evt.data);
};
This script pops up an alert box that says "You said: Hello, world".
"""
def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
**kwargs)
self.ws_connection = None
self.close_code = None
self.close_reason = None
self.stream = None
@tornado.web.asynchronous
def get(self, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.set_status(400)
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
return
# Connection header should be upgrade. Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.set_status(400)
self.finish("\"Connection\" must be \"Upgrade\".")
return
# Handle WebSocket Origin naming convention differences
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
# simply "Origin".
if "Origin" in self.request.headers:
origin = self.request.headers.get("Origin")
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
# If there was an origin header, check to make sure it matches
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.set_status(403)
self.finish("Cross origin websockets not allowed")
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
else:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
The message may be either a string or a dict (which will be
encoded as json). If the ``binary`` argument is false, the
message will be sent as utf8; in binary mode any byte string
is allowed.
If the connection is already closed, raises `WebSocketClosedError`.
.. versionchanged:: 3.2
`WebSocketClosedError` was added (previously a closed connection
would raise an `AttributeError`)
"""
if self.ws_connection is None:
raise WebSocketClosedError()
if isinstance(message, dict):
message = tornado.escape.json_encode(message)
self.ws_connection.write_message(message, binary=binary)
def select_subprotocol(self, subprotocols):
"""Invoked when a new WebSocket requests specific subprotocols.
``subprotocols`` is a list of strings identifying the
subprotocols proposed by the client. This method may be
overridden to return one of those strings to select it, or
``None`` to not select a subprotocol. Failure to select a
subprotocol does not automatically abort the connection,
although clients may close the connection if none of their
proposed subprotocols was selected.
"""
return None
def open(self):
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
regular expression, just like the arguments to
`tornado.web.RequestHandler.get`.
"""
pass
def on_message(self, message):
"""Handle incoming messages on the WebSocket
This method must be overridden.
"""
raise NotImplementedError
def ping(self, data):
"""Send ping frame to the remote end."""
if self.ws_connection is None:
raise WebSocketClosedError()
self.ws_connection.write_ping(data)
def on_pong(self, data):
"""Invoked when the response to a ping frame is received."""
pass
def on_close(self):
"""Invoked when the WebSocket is closed.
If the connection was closed cleanly and a status code or reason
phrase was supplied, these values will be available as the attributes
``self.close_code`` and ``self.close_reason``.
.. versionchanged:: 4.0
Added ``close_code`` and ``close_reason`` attributes.
"""
pass
def close(self, code=None, reason=None):
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
``code`` may be a numeric status code, taken from the values
defined in `RFC 6455 section 7.4.1
<https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
``reason`` may be a textual message about why the connection is
closing. These values are made available to the client, but are
not otherwise interpreted by the websocket protocol.
.. versionchanged:: 4.0
Added the ``code`` and ``reason`` arguments.
"""
if self.ws_connection:
self.ws_connection.close(code, reason)
self.ws_connection = None
def check_origin(self, origin):
"""Override to enable support for allowing alternate origins.
The ``origin`` argument is the value of the ``Origin`` HTTP
header, the url responsible for initiating this request. This
method is not called for clients that do not send this header;
such requests are always allowed (because all browsers that
implement WebSockets support this header, and non-browser
clients do not have the same cross-site security concerns).
Should return True to accept the request or False to reject it.
By default, rejects all requests with an origin on a host other
than this one.
This is a security protection against cross site scripting attacks on
browsers, since WebSockets are allowed to bypass the usual same-origin
policies and don't use CORS headers.
.. versionadded:: 4.0
"""
parsed_origin = urlparse(origin)
origin = parsed_origin.netloc
origin = origin.lower()
host = self.request.headers.get("Host")
# Check to see that origin matches host directly, including ports
return origin == host
def set_nodelay(self, value):
"""Set the no-delay flag for this stream.
By default, small messages may be delayed and/or combined to minimize
the number of packets sent. This can sometimes cause 200-500ms delays
due to the interaction between Nagle's algorithm and TCP delayed
ACKs. To reduce this delay (at the expense of possibly increasing
bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
connection is established.
See `.BaseIOStream.set_nodelay` for additional details.
.. versionadded:: 3.1
"""
self.stream.set_nodelay(value)
def on_connection_close(self):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
self.on_close()
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
if self.stream is None:
method(self, *args, **kwargs)
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
class WebSocketProtocol(object):
"""Base class for WebSocket protocol versions.
"""
def __init__(self, handler):
self.handler = handler
self.request = handler.request
self.stream = handler.stream
self.client_terminated = False
self.server_terminated = False
def _run_callback(self, callback, *args, **kwargs):
"""Runs the given callback with exception handling.
On error, aborts the websocket connection and returns False.
"""
try:
callback(*args, **kwargs)
except Exception:
app_log.error("Uncaught exception in %s",
self.request.path, exc_info=True)
self._abort()
def on_connection_close(self):
self._abort()
def _abort(self):
"""Instantly aborts the WebSocket connection by closing the socket"""
self.client_terminated = True
self.server_terminated = True
self.stream.close() # forcibly tear down the connection
self.close() # let the subclass cleanup
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
def __init__(self, handler, mask_outgoing=False):
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
self._frame_opcode = None
self._masked_frame = None
self._frame_mask = None
self._frame_length = None
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
def accept_connection(self):
try:
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
return
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
@staticmethod
def compute_accept_value(key):
"""Computes the value for the Sec-WebSocket-Accept header,
given the value for Sec-WebSocket-Key.
"""
sha1 = hashlib.sha1()
sha1.update(utf8(key))
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
return native_str(base64.b64encode(sha1.digest()))
def _challenge_response(self):
return WebSocketProtocol13.compute_accept_value(
self.request.headers.get("Sec-Websocket-Key"))
def _accept_connection(self):
subprotocol_header = ''
subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
subprotocols = [s.strip() for s in subprotocols.split(',')]
if subprotocols:
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"%s"
"\r\n" % (self._challenge_response(), subprotocol_header)))
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_frame()
def _write_frame(self, fin, opcode, data):
if fin:
finbit = 0x80
else:
finbit = 0
frame = struct.pack("B", finbit | opcode)
l = len(data)
if self.mask_outgoing:
mask_bit = 0x80
else:
mask_bit = 0
if l < 126:
frame += struct.pack("B", l | mask_bit)
elif l <= 0xFFFF:
frame += struct.pack("!BH", 126 | mask_bit, l)
else:
frame += struct.pack("!BQ", 127 | mask_bit, l)
if self.mask_outgoing:
mask = os.urandom(4)
data = mask + _websocket_mask(mask, data)
frame += data
self.stream.write(frame)
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
if binary:
opcode = 0x2
else:
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
try:
self._write_frame(True, opcode, message)
except StreamClosedError:
self._abort()
def write_ping(self, data):
"""Send ping frame."""
assert isinstance(data, bytes_type)
self._write_frame(True, 0x9, data)
def _receive_frame(self):
try:
self.stream.read_bytes(2, self._on_frame_start)
except StreamClosedError:
self._abort()
def _on_frame_start(self, data):
header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80
reserved_bits = header & 0x70
self._frame_opcode = header & 0xf
self._frame_opcode_is_control = self._frame_opcode & 0x8
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
return
self._masked_frame = bool(payloadlen & 0x80)
payloadlen = payloadlen & 0x7f
if self._frame_opcode_is_control and payloadlen >= 126:
# control frames must have payload < 126
self._abort()
return
try:
if payloadlen < 126:
self._frame_length = payloadlen
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
self.stream.read_bytes(8, self._on_frame_length_64)
except StreamClosedError:
self._abort()
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0]
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_masking_key(self, data):
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
except StreamClosedError:
self._abort()
def _on_masked_frame_data(self, data):
self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data):
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
# self._fragmented_*
if not self._final_frame:
# control frames must not be fragmented
self._abort()
return
opcode = self._frame_opcode
elif self._frame_opcode == 0: # continuation frame
if self._fragmented_message_buffer is None:
# nothing to continue
self._abort()
return
self._fragmented_message_buffer += data
if self._final_frame:
opcode = self._fragmented_message_opcode
data = self._fragmented_message_buffer
self._fragmented_message_buffer = None
else: # start of new data message
if self._fragmented_message_buffer is not None:
# can't start new message until the old one is finished
self._abort()
return
if self._final_frame:
opcode = self._frame_opcode
else:
self._fragmented_message_opcode = self._frame_opcode
self._fragmented_message_buffer = data
if self._final_frame:
self._handle_message(opcode, data)
if not self.client_terminated:
self._receive_frame()
def _handle_message(self, opcode, data):
if self.client_terminated:
return
if opcode == 0x1:
# UTF-8 data
try:
decoded = data.decode("utf-8")
except UnicodeDecodeError:
self._abort()
return
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
self._run_callback(self.handler.on_message, data)
elif opcode == 0x8:
# Close
self.client_terminated = True
if len(data) >= 2:
self.handler.close_code = struct.unpack('>H', data[:2])[0]
if len(data) > 2:
self.handler.close_reason = to_unicode(data[2:])
self.close()
elif opcode == 0x9:
# Ping
self._write_frame(True, 0xA, data)
elif opcode == 0xA:
# Pong
self._run_callback(self.handler.on_pong, data)
else:
self._abort()
def close(self, code=None, reason=None):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
if code is None and reason is not None:
code = 1000 # "normal closure" status code
if code is None:
close_data = b''
else:
close_data = struct.pack('>H', code)
if reason is not None:
close_data += utf8(reason)
self._write_frame(True, 0x8, close_data)
self.server_terminated = True
if self.client_terminated:
if self._waiting is not None:
self.stream.io_loop.remove_timeout(self._waiting)
self._waiting = None
self.stream.close()
elif self._waiting is None:
# Give the client a few seconds to complete a clean shutdown,
# otherwise just close the connection.
self._waiting = self.stream.io_loop.add_timeout(
self.stream.io_loop.time() + 5, self._abort)
class WebSocketClientConnection(simple_httpclient._HTTPConnection):
"""WebSocket client connection.
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
def __init__(self, io_loop, request):
self.connect_future = TracebackFuture()
self.read_future = None
self.read_queue = collections.deque()
self.key = base64.b64encode(os.urandom(16))
scheme, sep, rest = request.url.partition(':')
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
request.url = scheme + sep + rest
request.headers.update({
'Upgrade': 'websocket',
'Connection': 'Upgrade',
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.tcp_client, 65536)
def close(self, code=None, reason=None):
"""Closes the websocket connection.
``code`` and ``reason`` are documented under
`WebSocketHandler.close`.
.. versionadded:: 3.2
.. versionchanged:: 4.0
Added the ``code`` and ``reason`` arguments.
"""
if self.protocol is not None:
self.protocol.close(code, reason)
self.protocol = None
def _on_close(self):
self.on_message(None)
self.resolver.close()
super(WebSocketClientConnection, self)._on_close()
def _on_http_response(self, response):
if not self.connect_future.done():
if response.error:
self.connect_future.set_exception(response.error)
else:
self.connect_future.set_exception(WebSocketError(
"Non-websocket response"))
def headers_received(self, start_line, headers):
if start_line.code != 101:
return super(WebSocketClientConnection, self).headers_received(
start_line, headers)
self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
assert self.headers['Sec-Websocket-Accept'] == accept
self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
self.protocol._receive_frame()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
self.stream = self.connection.detach()
self.stream.set_close_callback(self._on_close)
self.connect_future.set_result(self)
def write_message(self, message, binary=False):
"""Sends a message to the WebSocket server."""
self.protocol.write_message(message, binary)
def read_message(self, callback=None):
"""Reads a message from the WebSocket server.
Returns a future whose result is the message, or None
if the connection is closed. If a callback argument
is given it will be called with the future when it is
ready.
"""
assert self.read_future is None
future = TracebackFuture()
if self.read_queue:
future.set_result(self.read_queue.popleft())
else:
self.read_future = future
if callback is not None:
self.io_loop.add_future(future, callback)
return future
def on_message(self, message):
if self.read_future is not None:
self.read_future.set_result(message)
self.read_future = None
else:
self.read_queue.append(message)
def on_pong(self, data):
pass
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
`WebSocketClientConnection`.
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
"""
if io_loop is None:
io_loop = IOLoop.current()
if isinstance(url, httpclient.HTTPRequest):
assert connect_timeout is None
request = url
# Copy and convert the headers dict/object (see comments in
# AsyncHTTPClient.fetch)
request.headers = httputil.HTTPHeaders(request.headers)
else:
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future

View file

@ -0,0 +1,361 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""WSGI support for the Tornado web framework.
WSGI is the Python standard for web servers, and allows for interoperability
between Tornado and other Python web frameworks and servers. This module
provides WSGI support in two ways:
* `WSGIAdapter` converts a `tornado.web.Application` to the WSGI application
interface. This is useful for running a Tornado app on another
HTTP server, such as Google App Engine. See the `WSGIAdapter` class
documentation for limitations that apply.
* `WSGIContainer` lets you run other WSGI applications and frameworks on the
Tornado HTTP server. For example, with this class you can mix Django
and Tornado handlers in a single server.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
import tornado
from tornado.concurrent import Future
from tornado import escape
from tornado import httputil
from tornado.log import access_log
from tornado import web
from tornado.escape import native_str
from tornado.util import bytes_type, unicode_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse
# PEP 3333 specifies that WSGI on python 3 generally deals with byte strings
# that are smuggled inside objects of type unicode (via the latin1 encoding).
# These functions are like those in the tornado.escape module, but defined
# here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode_type:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s.decode('latin1')
def from_wsgi_str(s):
assert isinstance(s, str)
return s.encode('latin1')
else:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s
def from_wsgi_str(s):
assert isinstance(s, str)
return s
class WSGIApplication(web.Application):
"""A WSGI equivalent of `tornado.web.Application`.
.. deprecated:: 4.0
Use a regular `.Application` and wrap it in `WSGIAdapter` instead.
"""
def __call__(self, environ, start_response):
return WSGIAdapter(self)(environ, start_response)
# WSGI has no facilities for flow control, so just return an already-done
# Future when the interface requires it.
_dummy_future = Future()
_dummy_future.set_result(None)
class _WSGIConnection(httputil.HTTPConnection):
def __init__(self, method, start_response, context):
self.method = method
self.start_response = start_response
self.context = context
self._write_buffer = []
self._finished = False
self._expected_content_remaining = None
self._error = None
def set_close_callback(self, callback):
# WSGI has no facility for detecting a closed connection mid-request,
# so we can simply ignore the callback.
pass
def write_headers(self, start_line, headers, chunk=None, callback=None):
if self.method == 'HEAD':
self._expected_content_remaining = 0
elif 'Content-Length' in headers:
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
self.start_response(
'%s %s' % (start_line.code, start_line.reason),
[(native_str(k), native_str(v)) for (k, v) in headers.get_all()])
if chunk is not None:
self.write(chunk, callback)
elif callback is not None:
callback()
return _dummy_future
def write(self, chunk, callback=None):
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
self._error = httputil.HTTPOutputError(
"Tried to write more data than Content-Length")
raise self._error
self._write_buffer.append(chunk)
if callback is not None:
callback()
return _dummy_future
def finish(self):
if (self._expected_content_remaining is not None and
self._expected_content_remaining != 0):
self._error = httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length" %
self._expected_content_remaining)
raise self._error
self._finished = True
class _WSGIRequestContext(object):
def __init__(self, remote_ip, protocol):
self.remote_ip = remote_ip
self.protocol = protocol
def __str__(self):
return self.remote_ip
class WSGIAdapter(object):
"""Converts a `tornado.web.Application` instance into a WSGI application.
Example usage::
import tornado.web
import tornado.wsgi
import wsgiref.simple_server
class MainHandler(tornado.web.RequestHandler):
def get(self):
self.write("Hello, world")
if __name__ == "__main__":
application = tornado.web.Application([
(r"/", MainHandler),
])
wsgi_app = tornado.wsgi.WSGIAdapter(application)
server = wsgiref.simple_server.make_server('', 8888, wsgi_app)
server.serve_forever()
See the `appengine demo
<https://github.com/tornadoweb/tornado/tree/stable/demos/appengine>`_
for an example of using this module to run a Tornado app on Google
App Engine.
In WSGI mode asynchronous methods are not supported. This means
that it is not possible to use `.AsyncHTTPClient`, or the
`tornado.auth` or `tornado.websocket` modules.
.. versionadded:: 4.0
"""
def __init__(self, application):
if isinstance(application, WSGIApplication):
self.application = lambda request: web.Application.__call__(
application, request)
else:
self.application = application
def __call__(self, environ, start_response):
method = environ["REQUEST_METHOD"]
uri = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", "")))
uri += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
if environ.get("QUERY_STRING"):
uri += "?" + environ["QUERY_STRING"]
headers = httputil.HTTPHeaders()
if environ.get("CONTENT_TYPE"):
headers["Content-Type"] = environ["CONTENT_TYPE"]
if environ.get("CONTENT_LENGTH"):
headers["Content-Length"] = environ["CONTENT_LENGTH"]
for key in environ:
if key.startswith("HTTP_"):
headers[key[5:].replace("_", "-")] = environ[key]
if headers.get("Content-Length"):
body = environ["wsgi.input"].read(
int(headers["Content-Length"]))
else:
body = ""
protocol = environ["wsgi.url_scheme"]
remote_ip = environ.get("REMOTE_ADDR", "")
if environ.get("HTTP_HOST"):
host = environ["HTTP_HOST"]
else:
host = environ["SERVER_NAME"]
connection = _WSGIConnection(method, start_response,
_WSGIRequestContext(remote_ip, protocol))
request = httputil.HTTPServerRequest(
method, uri, "HTTP/1.1", headers=headers, body=body,
host=host, connection=connection)
request._parse_body()
self.application(request)
if connection._error:
raise connection._error
if not connection._finished:
raise Exception("request did not finish synchronously")
return connection._write_buffer
class WSGIContainer(object):
r"""Makes a WSGI-compatible function runnable on Tornado's HTTP server.
.. warning::
WSGI is a *synchronous* interface, while Tornado's concurrency model
is based on single-threaded asynchronous execution. This means that
running a WSGI app with Tornado's `WSGIContainer` is *less scalable*
than running the same app in a multi-threaded WSGI server like
``gunicorn`` or ``uwsgi``. Use `WSGIContainer` only when there are
benefits to combining Tornado and WSGI in the same process that
outweigh the reduced scalability.
Wrap a WSGI function in a `WSGIContainer` and pass it to `.HTTPServer` to
run it. For example::
def simple_app(environ, start_response):
status = "200 OK"
response_headers = [("Content-type", "text/plain")]
start_response(status, response_headers)
return ["Hello world!\n"]
container = tornado.wsgi.WSGIContainer(simple_app)
http_server = tornado.httpserver.HTTPServer(container)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
This class is intended to let other frameworks (Django, web.py, etc)
run on the Tornado HTTP server and I/O loop.
The `tornado.web.FallbackHandler` class is often useful for mixing
Tornado and WSGI apps in the same server. See
https://github.com/bdarnell/django-tornado-demo for a complete example.
"""
def __init__(self, wsgi_application):
self.wsgi_application = wsgi_application
def __call__(self, request):
data = {}
response = []
def start_response(status, response_headers, exc_info=None):
data["status"] = status
data["headers"] = response_headers
return response.append
app_response = self.wsgi_application(
WSGIContainer.environ(request), start_response)
try:
response.extend(app_response)
body = b"".join(response)
finally:
if hasattr(app_response, "close"):
app_response.close()
if not data:
raise Exception("WSGI app did not call start_response")
status_code = int(data["status"].split()[0])
headers = data["headers"]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
if status_code != 304:
if "content-length" not in header_set:
headers.append(("Content-Length", str(len(body))))
if "content-type" not in header_set:
headers.append(("Content-Type", "text/html; charset=UTF-8"))
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))
parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
for key, value in headers:
parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
parts.append(b"\r\n")
parts.append(body)
request.write(b"".join(parts))
request.finish()
self._log(status_code, request)
@staticmethod
def environ(request):
"""Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment.
"""
hostport = request.host.split(":")
if len(hostport) == 2:
host = hostport[0]
port = int(hostport[1])
else:
host = request.host
port = 443 if request.protocol == "https" else 80
environ = {
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": "",
"PATH_INFO": to_wsgi_str(escape.url_unescape(
request.path, encoding=None, plus=False)),
"QUERY_STRING": request.query,
"REMOTE_ADDR": request.remote_ip,
"SERVER_NAME": host,
"SERVER_PORT": str(port),
"SERVER_PROTOCOL": request.version,
"wsgi.version": (1, 0),
"wsgi.url_scheme": request.protocol,
"wsgi.input": BytesIO(escape.utf8(request.body)),
"wsgi.errors": sys.stderr,
"wsgi.multithread": False,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
if "Content-Type" in request.headers:
environ["CONTENT_TYPE"] = request.headers.pop("Content-Type")
if "Content-Length" in request.headers:
environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length")
for key, value in request.headers.items():
environ["HTTP_" + key.replace("-", "_").upper()] = value
return environ
def _log(self, status_code, request):
if status_code < 400:
log_method = access_log.info
elif status_code < 500:
log_method = access_log.warning
else:
log_method = access_log.error
request_time = 1000.0 * request.request_time()
summary = request.method + " " + request.uri + " (" + \
request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)
HTTPRequest = httputil.HTTPServerRequest