Open Media Library Platform

j 2013-10-11 19:28:32 +02:00
commit 411ad5b16f
5849 changed files with 1778641 additions and 0 deletions


@@ -0,0 +1,29 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""The Tornado web server and tools."""
from __future__ import absolute_import, division, print_function, with_statement
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
# three numbers are the components of the version number. The fourth
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "3.1.1"
version_info = (3, 1, 1, 0)
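For illustration, a minimal sketch of the programmatic comparison that version_info enables (tuples compare element-wise):

import tornado

if tornado.version_info >= (3, 1):
    print("Tornado %s is at least 3.1" % tornado.version)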

File diff suppressed because it is too large


@@ -0,0 +1,316 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""xAutomatically restart the server when a source file is modified.
Most applications should not access this module directly. Instead, pass the
keyword argument ``debug=True`` to the `tornado.web.Application` constructor.
This will enable autoreload mode as well as checking for changes to templates
and static resources. Note that restarting is a destructive operation
and any requests in progress will be aborted when the process restarts.
This module can also be used as a command-line wrapper around scripts
such as unit test runners. See the `main` method for details.
The command-line wrapper and Application debug modes can be used together.
This combination is encouraged as the wrapper catches syntax errors and
other import-time failures, while debug mode catches changes once
the server has started.
This module depends on `.IOLoop`, so it will not work in WSGI applications
and Google App Engine. It also will not work correctly when `.HTTPServer`'s
multi-process mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
Additionally, modifying these variables will cause reloading to behave
incorrectly.
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
import sys
# sys.path handling
# -----------------
#
# If a module is run with "python -m", the current directory (i.e. "")
# is automatically prepended to sys.path, but not if it is run as
# "path/to/file.py". The processing for "-m" rewrites the former to
# the latter, so subsequent executions won't have the same path as the
# original.
#
# Conversely, when run as path/to/file.py, the directory containing
# file.py gets added to the path, which can cause confusion as imports
# may become relative in spite of the future import.
#
# We address the former problem by setting the $PYTHONPATH environment
# variable before re-execution so the new process will see the correct
# path. We attempt to address the latter problem when tornado.autoreload
# is run as __main__, although we can't fix the general case because
# we cannot reliably reconstruct the original command line
# (http://bugs.python.org/issue14208).
if __name__ == "__main__":
# This sys.path manipulation must come before our imports (as much
# as possible - if we introduced a tornado.sys or tornado.os
# module we'd be in trouble), or else our imports would become
# relative again despite the future import.
#
# There is a separate __main__ block at the end of the file to call main().
if sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
import functools
import logging
import os
import pkgutil
import sys
import traceback
import types
import subprocess
import weakref
from tornado import ioloop
from tornado.log import gen_log
from tornado import process
from tornado.util import exec_in
try:
import signal
except ImportError:
signal = None
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary()
def start(io_loop=None, check_time=500):
"""Begins watching source files for changes using the given `.IOLoop`. """
io_loop = io_loop or ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
scheduler.start()
def wait():
"""Wait for a watched file to change, then restart the process.
Intended to be used at the end of scripts like unit test runners,
to run the tests again after any source file changes (but see also
the command-line interface in `main`)
"""
io_loop = ioloop.IOLoop()
start(io_loop)
io_loop.start()
def watch(filename):
"""Add a file to the watch list.
All imported modules are watched by default.
"""
_watched_files.add(filename)
def add_reload_hook(fn):
"""Add a function to be called before reloading the process.
Note that for open file and socket handles it is generally
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
``tornado.platform.auto.set_close_exec``) instead
of using a reload hook to close them.
"""
_reload_hooks.append(fn)
def _reload_on_update(modify_times):
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
if process.task_id() is not None:
# We're in a child process created by fork_processes. If child
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in sys.modules.values():
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
# module.
if not isinstance(module, types.ModuleType):
continue
path = getattr(module, "__file__", None)
if not path:
continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
_check_file(modify_times, path)
for path in _watched_files:
_check_file(modify_times, path)
def _check_file(modify_times, path):
try:
modified = os.stat(path).st_mtime
except Exception:
return
if path not in modify_times:
modify_times[path] = modified
return
if modify_times[path] != modified:
gen_log.info("%s modified; restarting server", path)
_reload()
def _reload():
global _reload_attempted
_reload_attempted = True
for fn in _reload_hooks:
fn()
if hasattr(signal, "setitimer"):
# Clear the alarm signal set by
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
# sys.path fixes: see comments at top of file. If sys.path[0] is an empty
# string, we were (probably) invoked with -m and the effective path
# is about to change on re-exec. Add the current directory to $PYTHONPATH
# to ensure that the new process sees the same path we did.
path_prefix = '.' + os.pathsep
if (sys.path[0] == '' and
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if sys.platform == 'win32':
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
subprocess.Popen([sys.executable] + sys.argv)
sys.exit(0)
else:
try:
os.execv(sys.executable, [sys.executable] + sys.argv)
except OSError:
# Mac OS X versions prior to 10.6 do not support execv in
# a process that contains multiple threads. Instead of
# re-executing in the current process, start a new one
# and cause the current process to exit. This isn't
# ideal since the new process is detached from the parent
# terminal and thus cannot easily be killed with ctrl-C,
# but it's better than not being able to autoreload at
# all.
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable,
[sys.executable] + sys.argv)
sys.exit(0)
_USAGE = """\
Usage:
python -m tornado.autoreload -m module.to.run [args...]
python -m tornado.autoreload path/to/script.py [args...]
"""
def main():
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
python -m tornado.autoreload -m tornado.test.runtests
python -m tornado.autoreload tornado/test/runtests.py
Running a script with this wrapper is similar to calling
`tornado.autoreload.wait` at the end of the script, but this wrapper
can catch import-time problems like syntax errors that would otherwise
prevent the script from reaching its call to `wait`.
"""
original_argv = sys.argv
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
mode = "module"
module = sys.argv[2]
del sys.argv[1:3]
elif len(sys.argv) >= 2:
mode = "script"
script = sys.argv[1]
sys.argv = sys.argv[1:]
else:
print(_USAGE, file=sys.stderr)
sys.exit(1)
try:
if mode == "module":
import runpy
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
global __file__
__file__ = script
# Use globals as our "locals" dictionary so that
# something that tries to import __main__ (e.g. the unittest
# module) will see the right things.
exec_in(f.read(), globals(), globals())
except SystemExit as e:
logging.basicConfig()
gen_log.info("Script exited with status %s", e.code)
except Exception as e:
logging.basicConfig()
gen_log.warning("Script exited with uncaught exception", exc_info=True)
# If an exception occurred at import time, the file with the error
# never made it into sys.modules and so we won't know to watch it.
# Just to make sure we've covered everything, walk the stack trace
# from the exception and watch every file.
for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]):
watch(filename)
if isinstance(e, SyntaxError):
# SyntaxErrors are special: their innermost stack frame is fake
# so extract_tb won't see it and we have to get the filename
# from the exception object.
watch(e.filename)
else:
logging.basicConfig()
gen_log.info("Script exited normally")
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
if mode == 'module':
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
if loader is not None:
watch(loader.get_filename())
wait()
if __name__ == "__main__":
# See also the other __main__ block at the top of the file, which modifies
# sys.path before our imports
main()
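A minimal sketch of the wait() pattern described in the module docstring, assuming a unittest-based test script:

import unittest

import tornado.autoreload

if __name__ == "__main__":
    try:
        unittest.main()
    except SystemExit:
        pass  # unittest.main() always exits; swallow it so wait() can run
    # Block until a watched source file changes, then re-execute the process.
    tornado.autoreload.wait()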

File diff suppressed because it is too large


@@ -0,0 +1,265 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with threads and ``Futures``.
``Futures`` are a pattern for concurrent programming introduced in
Python 3.2 in the `concurrent.futures` package (this package has also
been backported to older versions of Python and can be installed with
``pip install futures``). Tornado will use `concurrent.futures.Future` if
it is available; otherwise it will use a compatible class defined in this
module.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import sys
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
try:
from concurrent import futures
except ImportError:
futures = None
class ReturnValueIgnoredError(Exception):
pass
class _DummyFuture(object):
def __init__(self):
self._done = False
self._result = None
self._exception = None
self._callbacks = []
def cancel(self):
return False
def cancelled(self):
return False
def running(self):
return not self._done
def done(self):
return self._done
def result(self, timeout=None):
self._check_done()
if self._exception:
raise self._exception
return self._result
def exception(self, timeout=None):
self._check_done()
if self._exception:
return self._exception
else:
return None
def add_done_callback(self, fn):
if self._done:
fn(self)
else:
self._callbacks.append(fn)
def set_result(self, result):
self._result = result
self._set_done()
def set_exception(self, exception):
self._exception = exception
self._set_done()
def _check_done(self):
if not self._done:
raise Exception("DummyFuture does not support blocking for results")
def _set_done(self):
self._done = True
for cb in self._callbacks:
# TODO: error handling
cb(self)
self._callbacks = None
if futures is None:
Future = _DummyFuture
else:
Future = futures.Future
class TracebackFuture(Future):
"""Subclass of `Future` which can store a traceback with
exceptions.
The traceback is automatically available in Python 3, but in the
Python 2 futures backport this information is discarded.
"""
def __init__(self):
super(TracebackFuture, self).__init__()
self.__exc_info = None
def exc_info(self):
return self.__exc_info
def set_exc_info(self, exc_info):
"""Traceback-aware replacement for
`~concurrent.futures.Future.set_exception`.
"""
self.__exc_info = exc_info
self.set_exception(exc_info[1])
def result(self):
if self.__exc_info is not None:
raise_exc_info(self.__exc_info)
else:
return super(TracebackFuture, self).result()
class DummyExecutor(object):
def submit(self, fn, *args, **kwargs):
future = TracebackFuture()
try:
future.set_result(fn(*args, **kwargs))
except Exception:
future.set_exc_info(sys.exc_info())
return future
def shutdown(self, wait=True):
pass
dummy_executor = DummyExecutor()
def run_on_executor(fn):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
"""
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
future = self.executor.submit(fn, self, *args, **kwargs)
if callback:
self.io_loop.add_future(future,
lambda future: callback(future.result()))
return future
return wrapper
_NO_RESULT = object()
def return_future(f):
"""Decorator to make a function that returns via callback return a
`Future`.
The wrapped function should take a ``callback`` keyword argument
and invoke it with one argument when it has finished. To signal failure,
the function can simply raise an exception (which will be
captured by the `.StackContext` and passed along to the ``Future``).
From the caller's perspective, the callback argument is optional.
If one is given, it will be invoked when the function is complete
with `Future.result()` as an argument. If the function fails, the
callback will not be run and an exception will be raised into the
surrounding `.StackContext`.
If no callback is given, the caller should use the ``Future`` to
wait for the function to complete (perhaps by yielding it in a
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
Usage::
@return_future
def future_func(arg1, arg2, callback):
# Do stuff (possibly asynchronous)
callback(result)
@gen.engine
def caller(callback):
yield future_func(arg1, arg2)
callback()
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
"""
replacer = ArgReplacer(f, 'callback')
@functools.wraps(f)
def wrapper(*args, **kwargs):
future = TracebackFuture()
callback, args, kwargs = replacer.replace(
lambda value=_NO_RESULT: future.set_result(value),
args, kwargs)
def handle_error(typ, value, tb):
future.set_exc_info((typ, value, tb))
return True
exc_info = None
with ExceptionStackContext(handle_error):
try:
result = f(*args, **kwargs)
if result is not None:
raise ReturnValueIgnoredError(
"@return_future should not be used with functions "
"that return values")
except:
exc_info = sys.exc_info()
raise
if exc_info is not None:
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
raise_exc_info(exc_info)
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens
# just before we return the future, or else we risk confusing
# stack contexts with multiple exceptions (one here with the
# immediate exception, and again when the future resolves and
# the callback triggers its exception by calling future.result()).
if callback is not None:
def run_callback(future):
result = future.result()
if result is _NO_RESULT:
callback()
else:
callback(future.result())
future.add_done_callback(wrap(run_callback))
return future
return wrapper
def chain_future(a, b):
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``.
"""
def copy(future):
assert future is a
if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
and a.exc_info() is not None):
b.set_exc_info(a.exc_info())
elif a.exception() is not None:
b.set_exception(a.exception())
else:
b.set_result(a.result())
a.add_done_callback(copy)
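A minimal sketch of run_on_executor in use, assuming concurrent.futures (or the futures backport) is available. The wrapper above looks up self.executor and self.io_loop by name; the blocking socket call here is just an illustration:

import socket
from concurrent.futures import ThreadPoolExecutor

from tornado.concurrent import run_on_executor
from tornado.ioloop import IOLoop

class Resolver(object):
    def __init__(self):
        # run_on_executor's wrapper reads these two attributes.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.io_loop = IOLoop.current()

    @run_on_executor
    def resolve(self, name):
        # Runs on a worker thread; the caller immediately gets a Future.
        return socket.gethostbyname(name)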


@@ -0,0 +1,478 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Non-blocking HTTP client implementation using pycurl."""
from __future__ import absolute_import, division, print_function, with_statement
import collections
import logging
import pycurl
import threading
import time
from tornado import httputil
from tornado import ioloop
from tornado.log import gen_log
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.util import bytes_type
try:
from io import BytesIO # py3
except ImportError:
from cStringIO import StringIO as BytesIO # py2
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, io_loop, max_clients=10, defaults=None):
super(CurlAsyncHTTPClient, self).initialize(io_loop, defaults=defaults)
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [_curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = collections.deque()
self._fds = {}
self._timeout = None
try:
self._socket_action = self._multi.socket_action
except AttributeError:
# socket_action is found in pycurl since 7.18.2 (it's been
# in libcurl longer than that but wasn't accessible to
# python).
gen_log.warning("socket_action method missing from pycurl; "
"falling back to socket_all. Upgrading "
"libcurl and pycurl will improve performance")
self._socket_action = \
lambda fd, action: self._multi.socket_all()
# libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
self._handle_force_timeout, 1000, io_loop=io_loop)
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
# multi object are initialized lazily, and its destructor will
# segfault if it is destroyed without having been used. Add
# and remove a dummy handle to make sure everything is
# initialized.
dummy_curl_handle = pycurl.Curl()
self._multi.add_handle(dummy_curl_handle)
self._multi.remove_handle(dummy_curl_handle)
def close(self):
self._force_timeout_callback.stop()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
for curl in self._curls:
curl.close()
self._multi.close()
self._closed = True
super(CurlAsyncHTTPClient, self).close()
def fetch_impl(self, request, callback):
self._requests.append((request, callback))
self._process_queue()
self._set_timeout(0)
def _handle_socket(self, event, fd, multi, data):
"""Called by libcurl when it wants to change the file descriptors
it cares about.
"""
event_map = {
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
self.io_loop.remove_handler(fd)
del self._fds[fd]
else:
ioloop_event = event_map[event]
# libcurl sometimes closes a socket and then opens a new
# one using the same FD without giving us a POLL_NONE in
# between. This is a problem with the epoll IOLoop,
# because the kernel can tell when a socket is closed and
# removes it from the epoll automatically, causing future
# update_handler calls to fail. Since we can't tell when
# this has happened, always use remove and re-add
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
self.io_loop.add_handler(fd, self._handle_events,
ioloop_event)
self._fds[fd] = ioloop_event
def _set_timeout(self, msecs):
"""Called by libcurl to schedule a timeout."""
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
def _handle_events(self, fd, events):
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
action = 0
if events & ioloop.IOLoop.READ:
action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE:
action |= pycurl.CSELECT_OUT
while True:
try:
ret, num_handles = self._socket_action(fd, action)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _handle_timeout(self):
"""Called by IOLoop when the requested timeout has passed."""
with stack_context.NullContext():
self._timeout = None
while True:
try:
ret, num_handles = self._socket_action(
pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
# In theory, we shouldn't have to do this because curl will
# call _set_timeout whenever the timeout changes. However,
# sometimes after _handle_timeout we will need to reschedule
# immediately even though nothing has changed from curl's
# perspective. This is because when socket_action is
# called with SOCKET_TIMEOUT, libcurl decides internally which
# timeouts need to be processed by using a monotonic clock
# (where available) while tornado uses python's time.time()
# to decide when timeouts have occurred. When those clocks
# disagree on elapsed time (as they will whenever there is an
# NTP adjustment), tornado might call _handle_timeout before
# libcurl is ready. After each timeout, resync the scheduled
# timeout with libcurl's current state.
new_timeout = self._multi.timeout()
if new_timeout >= 0:
self._set_timeout(new_timeout)
def _handle_force_timeout(self):
"""Called by IOLoop periodically to ask libcurl to process any
events it may have forgotten about.
"""
with stack_context.NullContext():
while True:
try:
ret, num_handles = self._multi.socket_all()
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _finish_pending_requests(self):
"""Process any requests that were completed by the last
call to multi.socket_action.
"""
while True:
num_q, ok_list, err_list = self._multi.info_read()
for curl in ok_list:
self._finish(curl)
for curl, errnum, errmsg in err_list:
self._finish(curl, errnum, errmsg)
if num_q == 0:
break
self._process_queue()
def _process_queue(self):
with stack_context.NullContext():
while True:
started = 0
while self._free_list and self._requests:
started += 1
curl = self._free_list.pop()
(request, callback) = self._requests.popleft()
curl.info = {
"headers": httputil.HTTPHeaders(),
"buffer": BytesIO(),
"request": request,
"callback": callback,
"curl_start_time": time.time(),
}
# Disable IPv6 to mitigate the effects of this bug
# on curl versions <= 7.21.0
# http://sourceforge.net/tracker/?func=detail&aid=3017819&group_id=976&atid=100976
if pycurl.version_info()[2] <= 0x71500: # 7.21.0
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
_curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"])
self._multi.add_handle(curl)
if not started:
break
def _finish(self, curl, curl_error=None, curl_message=None):
info = curl.info
curl.info = None
self._multi.remove_handle(curl)
self._free_list.append(curl)
buffer = info["buffer"]
if curl_error:
error = CurlError(curl_error, curl_message)
code = error.code
effective_url = None
buffer.close()
buffer = None
else:
error = None
code = curl.getinfo(pycurl.HTTP_CODE)
effective_url = curl.getinfo(pycurl.EFFECTIVE_URL)
buffer.seek(0)
# the various curl timings are documented at
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
time_info = dict(
queue=info["curl_start_time"] - info["request"].start_time,
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
connect=curl.getinfo(pycurl.CONNECT_TIME),
pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME),
starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME),
total=curl.getinfo(pycurl.TOTAL_TIME),
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
info["callback"](HTTPResponse(
request=info["request"], code=code, headers=info["headers"],
buffer=buffer, effective_url=effective_url, error=error,
request_time=time.time() - info["curl_start_time"],
time_info=time_info))
except Exception:
self.handle_callback_exception(info["callback"])
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
class CurlError(HTTPError):
def __init__(self, errno, message):
HTTPError.__init__(self, 599, message)
self.errno = errno
def _curl_create():
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, _curl_debug)
return curl
def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.items()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
else:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: _curl_header_callback(headers, line))
if request.streaming_callback:
write_function = request.streaming_callback
else:
write_function = buffer.write
if bytes_type is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
# a fork/port. That version has a bug in which it passes unicode
# strings instead of bytes to the WRITEFUNCTION. This means that
# if you use a WRITEFUNCTION (which tornado always does), you cannot
# download arbitrary binary data. This needs to be fixed in the
# ported pycurl package, but in the meantime this lambda will
# make it work for downloading (utf8) text.
curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.use_gzip:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
else:
curl.setopt(pycurl.PROXY, '')
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
# (but see version check in _process_queue above)
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
request_buffer = BytesIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.INFILESIZE, len(request.body))
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(headers, header_line):
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()
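A minimal fetch sketch using this client; configure() must be called before the first AsyncHTTPClient instance is created:

from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop

AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")

def on_response(response):
    print(response.code)
    IOLoop.current().stop()

AsyncHTTPClient().fetch("http://tornadoweb.org/", on_response)
IOLoop.current().start()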


@@ -0,0 +1,380 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
from __future__ import absolute_import, division, print_function, with_statement
import re
import sys
from tornado.util import bytes_type, unicode_type, basestring_type, u
try:
from urllib.parse import parse_qs as _parse_qs # py3
except ImportError:
from urlparse import parse_qs as _parse_qs # Python 2.6+
try:
import htmlentitydefs # py2
except ImportError:
import html.entities as htmlentitydefs # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse # py2
import json
try:
unichr
except NameError:
unichr = chr
_XHTML_ESCAPE_RE = re.compile('[&<>"]')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;'}
def xhtml_escape(value):
"""Escapes a string so it is valid within HTML or XML."""
return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
to_basestring(value))
def xhtml_unescape(value):
"""Un-escapes an XML-escaped string."""
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
# The fact that json_encode wraps json.dumps is an implementation detail.
# Please see https://github.com/facebook/tornado/pull/706
# before sending a pull request that adds **kwargs to this function.
def json_encode(value):
"""JSON-encodes the given Python object."""
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
def json_decode(value):
"""Returns Python objects for the given JSON string."""
return json.loads(to_basestring(value))
def squeeze(value):
"""Replace all sequences of whitespace chars with a single space."""
return re.sub(r"[\x00-\x20]+", " ", value).strip()
def url_escape(value, plus=True):
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
as "+" instead of "%20". This is appropriate for query strings
but not for the path component of a URL. Note that this default
is the reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
quote = urllib_parse.quote_plus if plus else urllib_parse.quote
return quote(utf8(value))
# python 3 changed things around enough that we need two separate
# implementations of url_unescape. We also need our own implementation
# of parse_qs since python 3's version insists on decoding everything.
if sys.version_info[0] < 3:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote)
if encoding is None:
return unquote(utf8(value))
else:
return unicode_type(unquote(utf8(value)), encoding)
parse_qs_bytes = _parse_qs
else:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
if encoding is None:
if plus:
# unquote_to_bytes doesn't have a _plus variant
value = to_basestring(value).replace('+', ' ')
return urllib_parse.unquote_to_bytes(value)
else:
unquote = (urllib_parse.unquote_plus if plus
else urllib_parse.unquote)
return unquote(to_basestring(value), encoding=encoding)
def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
"""Parses a query string like urlparse.parse_qs, but returns the
values as byte strings.
Keys still become type str (interpreted as latin1 in python3!)
because it's too painful to keep them as byte strings in
python3 and in practice they're nearly always ascii anyway.
"""
# This is gross, but python3 doesn't give us another way.
# Latin1 is the universal donor of character encodings.
result = _parse_qs(qs, keep_blank_values, strict_parsing,
encoding='latin1', errors='strict')
encoded = {}
for k, v in result.items():
encoded[k] = [i.encode('latin1') for i in v]
return encoded
_UTF8_TYPES = (bytes_type, type(None))
def utf8(value):
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
Otherwise it must be a unicode string and is encoded as utf8.
"""
if isinstance(value, _UTF8_TYPES):
return value
assert isinstance(value, unicode_type), \
"Expected bytes, unicode, or None; got %r" % type(value)
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
def to_unicode(value):
"""Converts a string argument to a unicode string.
If the argument is already a unicode string or None, it is returned
unchanged. Otherwise it must be a byte string and is decoded as utf8.
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
assert isinstance(value, bytes_type), \
"Expected bytes, unicode, or None; got %r" % type(value)
return value.decode("utf-8")
# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode
# When dealing with the standard library across python 2 and 3 it is
# sometimes useful to have a direct conversion to the native string type
if str is unicode_type:
native_str = to_unicode
else:
native_str = utf8
_BASESTRING_TYPES = (basestring_type, type(None))
def to_basestring(value):
"""Converts a string argument to a subclass of basestring.
In python2, byte and unicode strings are mostly interchangeable,
so functions that deal with a user-supplied argument in combination
with ascii string constants can use either and should return the type
the user supplied. In python3, the two types are not interchangeable,
so this method is needed to convert byte strings to unicode.
"""
if isinstance(value, _BASESTRING_TYPES):
return value
assert isinstance(value, bytes_type), \
"Expected bytes, unicode, or None; got %r" % type(value)
return value.decode("utf-8")
def recursive_unicode(obj):
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes_type):
return to_unicode(obj)
else:
return obj
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
# dots), causing the regex matcher to never return.
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
_URL_RE = re.compile(to_unicode(r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)"""))
def linkify(text, shorten=False, extra_params="",
require_protocol=False, permitted_protocols=["http", "https"]):
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
Parameters:
* ``shorten``: Long urls will be shortened for display.
* ``extra_params``: Extra text to include in the link tag, or a callable
taking the link as an argument and returning the extra text
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
or::
def extra_params_cb(url):
if url.startswith("http://example.com"):
return 'class="internal"'
else:
return 'class="external" rel="nofollow"'
linkify(text, extra_params=extra_params_cb)
* ``require_protocol``: Only linkify urls which include a protocol. If
this is False, urls such as www.facebook.com will also be linkified.
* ``permitted_protocols``: List (or set) of protocols which should be
linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
"mailto"])``. It is very unsafe to include protocols such as
``javascript``.
"""
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
def make_link(m):
url = m.group(1)
proto = m.group(2)
if require_protocol and not proto:
return url # no protocol, no linkify
if proto and proto not in permitted_protocols:
return url # bad protocol, no linkify
href = m.group(1)
if not proto:
href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
else:
params = extra_params
# clip long urls. max_len is just an approximation
max_len = 30
if shorten and len(url) > max_len:
before_clip = url
if proto:
proto_len = len(proto) + 1 + len(m.group(3) or "") # +1 for :
else:
proto_len = 0
parts = url[proto_len:].split("/")
if len(parts) > 1:
# Grab the whole host part plus the first bit of the path
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
url = url[:proto_len] + parts[0] + "/" + \
parts[1][:8].split('?')[0].split('.')[0]
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
amp = url.rfind('&')
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
url += "..."
if len(url) >= len(before_clip):
url = before_clip
else:
# full url is visible on mouse-over (for those who don't
# have a status bar, such as Safari by default)
params += ' title="%s"' % href
return u('<a href="%s"%s>%s</a>') % (href, params, url)
# First HTML-escape so that our strings are all safe.
# The regex is modified to avoid character entities other than &amp; so
# that we won't pick up &quot;, etc.
text = _unicode(xhtml_escape(text))
return _URL_RE.sub(make_link, text)
def _convert_entity(m):
if m.group(1) == "#":
try:
return unichr(int(m.group(2)))
except ValueError:
return "&#%s;" % m.group(2)
try:
return _HTML_UNICODE_MAP[m.group(2)]
except KeyError:
return "&%s;" % m.group(2)
def _build_unicode_map():
unicode_map = {}
for name, value in htmlentitydefs.name2codepoint.items():
unicode_map[name] = unichr(value)
return unicode_map
_HTML_UNICODE_MAP = _build_unicode_map()
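A few small examples of the helpers above, with expected values sketched in comments:

from tornado.escape import json_encode, linkify, url_escape, xhtml_escape

xhtml_escape('a "quoted" <tag>')      # 'a &quot;quoted&quot; &lt;tag&gt;'
url_escape("a b")                     # 'a+b' (query-string style, plus=True)
url_escape("a b", plus=False)         # 'a%20b' (path style)
json_encode({"x": "</script>"})       # '{"x": "<\\/script>"}' (slash escaped)
linkify("see http://tornadoweb.org")  # wraps the URL in an <a href=...> tag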


@@ -0,0 +1,561 @@
"""``tornado.gen`` is a generator-based interface to make it easier to
work in an asynchronous environment. Code using the ``gen`` module
is technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
For example, the following asynchronous handler::
class AsyncHandler(RequestHandler):
@asynchronous
def get(self):
http_client = AsyncHTTPClient()
http_client.fetch("http://example.com",
callback=self.on_fetch)
def on_fetch(self, response):
do_something_with_response(response)
self.render("template.html")
could be written with ``gen`` as::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response = yield http_client.fetch("http://example.com")
do_something_with_response(response)
self.render("template.html")
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
For functions that do not return ``Futures``, `Task` works with any
function that takes a ``callback`` keyword argument (most Tornado functions
can be used in either style, although the ``Future`` style is preferred
since it is both shorter and provides better exception handling)::
@gen.coroutine
def get(self):
yield gen.Task(AsyncHTTPClient().fetch, "http://example.com")
You can also yield a list of ``Futures`` and/or ``Tasks``, which will be
started at the same time and run in parallel; a list of results will
be returned when they are all finished::
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response1, response2 = yield [http_client.fetch(url1),
http_client.fetch(url2)]
For more complicated interfaces, `Task` can be split into two parts:
`Callback` and `Wait`::
class GenAsyncHandler2(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
http_client.fetch("http://example.com",
callback=(yield gen.Callback("key")))
response = yield gen.Wait("key")
do_something_with_response(response)
self.render("template.html")
The ``key`` argument to `Callback` and `Wait` allows for multiple
asynchronous operations to be started at different times and proceed
in parallel: yield several callbacks with different keys, then wait
for them once all the async operations have started.
The result of a `Wait` or `Task` yield expression depends on how the callback
was run. If it was called with no arguments, the result is ``None``. If
it was called with one argument, the result is that argument. If it was
called with more than one argument or any keyword arguments, the result
is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
"""
from __future__ import absolute_import, division, print_function, with_statement
import collections
import functools
import itertools
import sys
import types
from tornado.concurrent import Future, TracebackFuture
from tornado.ioloop import IOLoop
from tornado.stack_context import ExceptionStackContext, wrap
class KeyReuseError(Exception):
pass
class UnknownKeyError(Exception):
pass
class LeakedCallbackError(Exception):
pass
class BadYieldError(Exception):
pass
class ReturnValueIgnoredError(Exception):
pass
def engine(func):
"""Callback-oriented decorator for asynchronous generators.
This is an older interface; for new code that does not need to be
compatible with versions of Tornado older than 3.0 the
`coroutine` decorator is recommended instead.
This decorator is similar to `coroutine`, except it does not
return a `.Future` and the ``callback`` argument is not treated
specially.
In most cases, functions decorated with `engine` should take
a ``callback`` argument and invoke it with their result when
they are finished. One notable exception is the
`~tornado.web.RequestHandler` :ref:`HTTP verb methods <verbs>`,
which use ``self.finish()`` in place of a callback argument.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
runner = None
def handle_exception(typ, value, tb):
# if the function throws an exception before its first "yield"
# (or is not a generator at all), the Runner won't exist yet.
# However, in that case we haven't reached anything asynchronous
# yet, so we can just let the exception propagate.
if runner is not None:
return runner.handle_exception(typ, value, tb)
return False
with ExceptionStackContext(handle_exception) as deactivate:
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
else:
if isinstance(result, types.GeneratorType):
def final_callback(value):
if value is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: "
"%r" % (value,))
assert value is None
deactivate()
runner = Runner(result, final_callback)
runner.run()
return
if result is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(result,))
deactivate()
# no yield, so we're done
return wrapper
def coroutine(func):
"""Decorator for asynchronous generators.
Any generator that yields objects from this module must be wrapped
in either this decorator or `engine`.
Coroutines may "return" by raising the special exception
`Return(value) <Return>`. In Python 3.3+, it is also possible for
the function to simply use the ``return value`` statement (prior to
Python 3.3 generators were not allowed to also return values).
In all versions of Python a coroutine that simply wishes to exit
early may use the ``return`` statement without a value.
Functions with this decorator return a `.Future`. Additionally,
they may be called with a ``callback`` keyword argument, which
will be invoked with the future's result when it resolves. If the
coroutine fails, the callback will not be run and an exception
will be raised into the surrounding `.StackContext`. The
``callback`` argument is not visible inside the decorated
function; it is handled by the decorator itself.
From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
runner = None
future = TracebackFuture()
if 'callback' in kwargs:
callback = kwargs.pop('callback')
IOLoop.current().add_future(
future, lambda future: callback(future.result()))
def handle_exception(typ, value, tb):
try:
if runner is not None and runner.handle_exception(typ, value, tb):
return True
except Exception:
typ, value, tb = sys.exc_info()
future.set_exc_info((typ, value, tb))
return True
with ExceptionStackContext(handle_exception) as deactivate:
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
except Exception:
deactivate()
future.set_exc_info(sys.exc_info())
return future
else:
if isinstance(result, types.GeneratorType):
def final_callback(value):
deactivate()
future.set_result(value)
runner = Runner(result, final_callback)
runner.run()
return future
deactivate()
future.set_result(result)
return future
return wrapper
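# Illustrative sketch of the contract the wrapper above implements:
#
#     @gen.coroutine
#     def divide(a, b):
#         raise gen.Return(a / b)  # plain ``return a / b`` works on Python 3.3+
#
#     # Callers receive a Future; a callback= keyword is also accepted:
#     divide(6, 2, callback=lambda result: print(result))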
class Return(Exception):
"""Special exception to return a value from a `coroutine`.
If this exception is raised, its value argument is used as the
result of the coroutine::
@gen.coroutine
def fetch_json(url):
response = yield AsyncHTTPClient().fetch(url)
raise gen.Return(json_decode(response.body))
In Python 3.3, this exception is no longer necessary: the ``return``
statement can be used directly to return a value (previously
``yield`` and ``return`` with a value could not be combined in the
same function).
By analogy with the return statement, the value argument is optional,
but it is never necessary to ``raise gen.Return()``. The ``return``
statement can be used with no arguments instead.
"""
def __init__(self, value=None):
super(Return, self).__init__()
self.value = value
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator.
Applications do not normally need to use this class, but it may be
subclassed to provide additional yielding behavior.
"""
def start(self, runner):
"""Called by the runner after the generator has yielded.
No other methods will be called on this object before ``start``.
"""
raise NotImplementedError()
def is_ready(self):
"""Called by the runner to determine whether to resume the generator.
Returns a boolean; may be called more than once.
"""
raise NotImplementedError()
def get_result(self):
"""Returns the value to use as the result of the yield expression.
This method will only be called once, and only after `is_ready`
has returned true.
"""
raise NotImplementedError()
class Callback(YieldPoint):
"""Returns a callable object that will allow a matching `Wait` to proceed.
The key may be any value suitable for use as a dictionary key, and is
used to match ``Callbacks`` to their corresponding ``Waits``. The key
must be unique among outstanding callbacks within a single run of the
generator function, but may be reused across different runs of the same
function (so constants generally work fine).
The callback may be called with zero or one arguments; if an argument
is given it will be returned by `Wait`.
"""
def __init__(self, key):
self.key = key
def start(self, runner):
self.runner = runner
runner.register_callback(self.key)
def is_ready(self):
return True
def get_result(self):
return self.runner.result_callback(self.key)
class Wait(YieldPoint):
"""Returns the argument passed to the result of a previous `Callback`."""
def __init__(self, key):
self.key = key
def start(self, runner):
self.runner = runner
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key)
class WaitAll(YieldPoint):
"""Returns the results of multiple previous `Callbacks <Callback>`.
The argument is a sequence of `Callback` keys, and the result is
a list of results in the same order.
`WaitAll` is equivalent to yielding a list of `Wait` objects.
"""
def __init__(self, keys):
self.keys = keys
def start(self, runner):
self.runner = runner
def is_ready(self):
return all(self.runner.is_ready(key) for key in self.keys)
def get_result(self):
return [self.runner.pop_result(key) for key in self.keys]
class Task(YieldPoint):
"""Runs a single asynchronous operation.
Takes a function (and optional additional arguments) and runs it with
those arguments plus a ``callback`` keyword argument. The argument passed
to the callback is returned as the result of the yield expression.
A `Task` is equivalent to a `Callback`/`Wait` pair (with a unique
key generated automatically)::
result = yield gen.Task(func, args)
func(args, callback=(yield gen.Callback(key)))
result = yield gen.Wait(key)
"""
def __init__(self, func, *args, **kwargs):
assert "callback" not in kwargs
self.args = args
self.kwargs = kwargs
self.func = func
def start(self, runner):
self.runner = runner
self.key = object()
runner.register_callback(self.key)
self.kwargs["callback"] = runner.result_callback(self.key)
self.func(*self.args, **self.kwargs)
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key)
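# Illustrative sketch, not part of the original module: the equivalence from
# the docstring above, collapsed into a single Task.  ``fetch_async`` is
# again a hypothetical callback-style helper.
@coroutine
def _task_example(url):
    # Task registers a fresh key, injects ``callback=`` into the call, and
    # resumes the generator with whatever argument the callback receives.
    result = yield Task(fetch_async, url)  # noqa: F821
    raise Return(result)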
class YieldFuture(YieldPoint):
def __init__(self, future, io_loop=None):
self.future = future
self.io_loop = io_loop or IOLoop.current()
def start(self, runner):
self.runner = runner
self.key = object()
runner.register_callback(self.key)
self.io_loop.add_future(self.future, runner.result_callback(self.key))
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key).result()
class Multi(YieldPoint):
"""Runs multiple asynchronous operations in parallel.
Takes a list of ``Tasks`` or other ``YieldPoints`` and returns a list of
their responses. It is not necessary to call `Multi` explicitly,
since the engine will do so automatically when the generator yields
a list of ``YieldPoints``.
"""
def __init__(self, children):
self.children = []
for i in children:
if isinstance(i, Future):
i = YieldFuture(i)
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
self.unfinished_children = set(self.children)
def start(self, runner):
for i in self.children:
i.start(runner)
def is_ready(self):
finished = list(itertools.takewhile(
lambda i: i.is_ready(), self.unfinished_children))
self.unfinished_children.difference_update(finished)
return not self.unfinished_children
def get_result(self):
return [i.get_result() for i in self.children]
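# Illustrative sketch, not part of the original module: yielding a list runs
# the operations in parallel, because the Runner wraps the list in Multi
# automatically (see Runner.run below).
@coroutine
def _parallel_example(urls):
    # Results arrive as a list, in the same order as the yielded list.
    responses = yield [Task(fetch_async, url) for url in urls]  # noqa: F821
    raise Return(responses)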
class _NullYieldPoint(YieldPoint):
def start(self, runner):
pass
def is_ready(self):
return True
def get_result(self):
return None
_null_yield_point = _NullYieldPoint()
class Runner(object):
"""Internal implementation of `tornado.gen.engine`.
Maintains information about pending callbacks and their results.
``final_callback`` is run after the generator exits.
"""
def __init__(self, gen, final_callback):
self.gen = gen
self.final_callback = final_callback
self.yield_point = _null_yield_point
self.pending_callbacks = set()
self.results = {}
self.running = False
self.finished = False
self.exc_info = None
self.had_exception = False
def register_callback(self, key):
"""Adds ``key`` to the list of callbacks."""
if key in self.pending_callbacks:
raise KeyReuseError("key %r is already pending" % (key,))
self.pending_callbacks.add(key)
def is_ready(self, key):
"""Returns true if a result is available for ``key``."""
if key not in self.pending_callbacks:
raise UnknownKeyError("key %r is not pending" % (key,))
return key in self.results
def set_result(self, key, result):
"""Sets the result for ``key`` and attempts to resume the generator."""
self.results[key] = result
self.run()
def pop_result(self, key):
"""Returns the result for ``key`` and unregisters it."""
self.pending_callbacks.remove(key)
return self.results.pop(key)
def run(self):
"""Starts or resumes the generator, running until it reaches a
yield point that is not ready.
"""
if self.running or self.finished:
return
try:
self.running = True
while True:
if self.exc_info is None:
try:
if not self.yield_point.is_ready():
return
next = self.yield_point.get_result()
self.yield_point = None
except Exception:
self.exc_info = sys.exc_info()
try:
if self.exc_info is not None:
self.had_exception = True
exc_info = self.exc_info
self.exc_info = None
yielded = self.gen.throw(*exc_info)
else:
yielded = self.gen.send(next)
except (StopIteration, Return) as e:
self.finished = True
self.yield_point = _null_yield_point
if self.pending_callbacks and not self.had_exception:
# If we ran cleanly without waiting on all callbacks
# raise an error (really more of a warning). If we
# had an exception then some callbacks may have been
# orphaned, so skip the check in that case.
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.final_callback(getattr(e, 'value', None))
self.final_callback = None
return
except Exception:
self.finished = True
self.yield_point = _null_yield_point
raise
if isinstance(yielded, list):
yielded = Multi(yielded)
elif isinstance(yielded, Future):
yielded = YieldFuture(yielded)
if isinstance(yielded, YieldPoint):
self.yield_point = yielded
try:
self.yield_point.start(self)
except Exception:
self.exc_info = sys.exc_info()
else:
self.exc_info = (BadYieldError(
"yielded unknown object %r" % (yielded,)),)
finally:
self.running = False
def result_callback(self, key):
def inner(*args, **kwargs):
if kwargs or len(args) > 1:
result = Arguments(args, kwargs)
elif args:
result = args[0]
else:
result = None
self.set_result(key, result)
return wrap(inner)
def handle_exception(self, typ, value, tb):
if not self.running and not self.finished:
self.exc_info = (typ, value, tb)
self.run()
return True
else:
return False
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
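# Illustrative standalone sketch, not part of the original file: the pieces
# above in action.  The Runner drives the generator; a yielded Future is
# wrapped in YieldFuture, and Return (or a bare StopIteration) ends the run
# and delivers the result through the final callback.
def _coroutine_demo():
    from tornado import gen, ioloop
    from tornado.httpclient import AsyncHTTPClient
    @gen.coroutine
    def fetch_body(url):
        response = yield AsyncHTTPClient().fetch(url)  # Future -> YieldFuture
        raise gen.Return(response.body)
    return ioloop.IOLoop.instance().run_sync(
        lambda: fetch_body("http://example.com/"))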


@ -0,0 +1,497 @@
"""Blocking and non-blocking HTTP client interfaces.
This module defines a common interface shared by two implementations,
``simple_httpclient`` and ``curl_httpclient``. Applications may either
instantiate their chosen implementation class directly or use the
`AsyncHTTPClient` class from this module, which selects an implementation
that can be overridden with the `AsyncHTTPClient.configure` method.
The default implementation is ``simple_httpclient``, and this is expected
to be suitable for most users' needs. However, some applications may wish
to switch to ``curl_httpclient`` for reasons such as the following:
* ``curl_httpclient`` has some features not found in ``simple_httpclient``,
including support for HTTP proxies and the ability to use a specified
network interface.
* ``curl_httpclient`` is more likely to be compatible with sites that are
not-quite-compliant with the HTTP spec, or sites that use little-exercised
features of HTTP.
* ``curl_httpclient`` is faster.
* ``curl_httpclient`` was the default prior to Tornado 2.0.
Note that if you are using ``curl_httpclient``, it is highly recommended that
you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum
supported version is 7.18.2, and the recommended version is 7.21.1 or newer.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import time
import weakref
from tornado.concurrent import Future
from tornado.escape import utf8
from tornado import httputil, stack_context
from tornado.ioloop import IOLoop
from tornado.util import Configurable
class HTTPClient(object):
"""A blocking HTTP client.
This interface is provided for convenience and testing; most applications
that are running an IOLoop will want to use `AsyncHTTPClient` instead.
Typical usage looks like this::
http_client = httpclient.HTTPClient()
try:
response = http_client.fetch("http://www.google.com/")
print(response.body)
except httpclient.HTTPError as e:
print("Error:", e)
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop()
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop, **kwargs)
self._closed = False
def __del__(self):
self.close()
def close(self):
"""Closes the HTTPClient, freeing any resources used."""
if not self._closed:
self._async_client.close()
self._io_loop.close()
self._closed = True
def fetch(self, request, **kwargs):
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError`.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
response.rethrow()
return response
class AsyncHTTPClient(Configurable):
"""An non-blocking HTTP client.
Example usage::
def handle_request(response):
if response.error:
print "Error:", response.error
else:
print response.body
http_client = AsyncHTTPClient()
http_client.fetch("http://www.google.com/", handle_request)
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
(one per `.IOLoop`). The keyword argument ``force_instance=True``
can be used to suppress this singleton behavior. Constructor
arguments other than ``io_loop`` and ``force_instance`` are
deprecated. The implementation subclass as well as arguments to
its constructor can be set with the static method `configure()`.
"""
@classmethod
def configurable_base(cls):
return AsyncHTTPClient
@classmethod
def configurable_default(cls):
from tornado.simple_httpclient import SimpleAsyncHTTPClient
return SimpleAsyncHTTPClient
@classmethod
def _async_clients(cls):
attr_name = '_async_client_dict_' + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
def __new__(cls, io_loop=None, force_instance=False, **kwargs):
io_loop = io_loop or IOLoop.current()
if io_loop in cls._async_clients() and not force_instance:
return cls._async_clients()[io_loop]
instance = super(AsyncHTTPClient, cls).__new__(cls, io_loop=io_loop,
**kwargs)
if not force_instance:
cls._async_clients()[io_loop] = instance
return instance
def initialize(self, io_loop, defaults=None):
self.io_loop = io_loop
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
def close(self):
"""Destroys this HTTP client, freeing any file descriptors used.
Not needed in normal use, but may be helpful in unittests that
create and destroy http clients. No other methods may be called
on the `AsyncHTTPClient` after ``close()``.
"""
if self._async_clients().get(self.io_loop) is self:
del self._async_clients()[self.io_loop]
def fetch(self, request, callback=None, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
the request returned a non-200 response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
"""
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = httputil.HTTPHeaders(request.headers)
request = _RequestProxy(request, self.defaults)
future = Future()
if callback is not None:
callback = stack_context.wrap(callback)
def handle_future(future):
exc = future.exception()
if isinstance(exc, HTTPError) and exc.response is not None:
response = exc.response
elif exc is not None:
response = HTTPResponse(
request, 599, error=exc,
request_time=time.time() - request.start_time)
else:
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
def handle_response(response):
if response.error:
future.set_exception(response.error)
else:
future.set_result(response)
self.fetch_impl(request, handle_response)
return future
def fetch_impl(self, request, callback):
raise NotImplementedError()
@classmethod
def configure(cls, impl, **kwargs):
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
This method may be called with either a class object or the
fully-qualified name of such a class (or ``None`` to use the default,
``SimpleAsyncHTTPClient``).
If additional keyword arguments are given, they will be passed
to the constructor of each subclass instance created. The
keyword argument ``max_clients`` determines the maximum number
of simultaneous `~AsyncHTTPClient.fetch()` operations that can
execute in parallel on each `.IOLoop`. Additional arguments
may be supported depending on the implementation class in use.
Example::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
super(AsyncHTTPClient, cls).configure(impl, **kwargs)
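# Illustrative sketch, not part of the original module: the two fetch()
# interfaces described above.  The Future form pairs naturally with
# tornado.gen; the callback form receives the HTTPResponse directly and
# never raises automatically.
def _fetch_examples(url):
    from tornado import gen
    # Optionally switch implementations before any client is created:
    # AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
    @gen.coroutine
    def future_style():
        response = yield AsyncHTTPClient().fetch(url)  # raises HTTPError on failure
        raise gen.Return(response.body)
    def callback_style():
        def on_response(response):
            if response.error:  # check explicitly, or call response.rethrow()
                response.rethrow()
        AsyncHTTPClient().fetch(url, callback=on_response)
    return future_style, callback_style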
class HTTPRequest(object):
"""HTTP client request object."""
# Default values for HTTPRequest parameters.
# Merged with the values on the request object by AsyncHTTPClient
# implementations.
_DEFAULTS = dict(
connect_timeout=20.0,
request_timeout=20.0,
follow_redirects=True,
max_redirects=5,
use_gzip=True,
proxy_password='',
allow_nonstandard_methods=False,
validate_cert=True)
def __init__(self, url, method="GET", headers=None, body=None,
auth_username=None, auth_password=None, auth_mode=None,
connect_timeout=None, request_timeout=None,
if_modified_since=None, follow_redirects=None,
max_redirects=None, user_agent=None, use_gzip=None,
network_interface=None, streaming_callback=None,
header_callback=None, prepare_curl_callback=None,
proxy_host=None, proxy_port=None, proxy_username=None,
proxy_password=None, allow_nonstandard_methods=None,
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
:arg string method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:arg body: HTTP body to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg string auth_username: Username for HTTP authentication
:arg string auth_password: Password for HTTP authentication
:arg string auth_mode: Authentication mode; default is "basic".
Allowed values are implementation-defined; ``curl_httpclient``
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds
:arg float request_timeout: Timeout for entire request in seconds
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
or return the 3xx response?
:arg int max_redirects: Limit for ``follow_redirects``
:arg string user_agent: String to send as ``User-Agent`` header
:arg bool use_gzip: Request gzip encoding from the server
:arg string network_interface: Network interface to use for request
:arg callable streaming_callback: If set, ``streaming_callback`` will
be run with each chunk of data as it is received, and
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
the final response.
:arg callable header_callback: If set, ``header_callback`` will
be run with each header line as it is received (including the
first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
containing only ``\r\n``; all lines include the trailing newline
characters). ``HTTPResponse.headers`` will be empty in the final
response. This is most useful in conjunction with
``streaming_callback``, because it's the only way to get access to
header data while the request is in progress.
:arg callable prepare_curl_callback: If set, will be called with
a ``pycurl.Curl`` object to allow the application to make additional
``setopt`` calls.
:arg string proxy_host: HTTP proxy hostname. To use proxies,
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
``proxy_password`` are optional. Proxies are currently only supported
with ``curl_httpclient``.
:arg int proxy_port: HTTP proxy port
:arg string proxy_username: HTTP proxy username
:arg string proxy_password: HTTP proxy password
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
argument?
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate?
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. Note that in ``curl_httpclient``, if
any request uses a custom ``ca_certs`` file, they all must (they
don't have to all use the same ``ca_certs``, but it's not possible
to mix requests with ``ca_certs`` and requests that use the defaults).
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
``simple_httpclient`` and true in ``curl_httpclient``
:arg string client_key: Filename for client SSL key, if any
:arg string client_cert: Filename for client SSL certificate, if any
.. versionadded:: 3.1
The ``auth_mode`` argument.
"""
if headers is None:
headers = httputil.HTTPHeaders()
if if_modified_since:
headers["If-Modified-Since"] = httputil.format_timestamp(
if_modified_since)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.url = url
self.method = method
self.headers = headers
self.body = utf8(body)
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
self.connect_timeout = connect_timeout
self.request_timeout = request_timeout
self.follow_redirects = follow_redirects
self.max_redirects = max_redirects
self.user_agent = user_agent
self.use_gzip = use_gzip
self.network_interface = network_interface
self.streaming_callback = stack_context.wrap(streaming_callback)
self.header_callback = stack_context.wrap(header_callback)
self.prepare_curl_callback = stack_context.wrap(prepare_curl_callback)
self.allow_nonstandard_methods = allow_nonstandard_methods
self.validate_cert = validate_cert
self.ca_certs = ca_certs
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.start_time = time.time()
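# Illustrative sketch, not part of the original module: constructing a
# request by hand.  Parameters left as None fall back to _DEFAULTS through
# _RequestProxy when the request is fetched.  The URL is hypothetical.
def _build_post_request():
    return HTTPRequest(
        "http://example.com/api",
        method="POST",
        headers={"Content-Type": "application/json"},
        body='{"key": "value"}',      # utf8-encoded by __init__
        request_timeout=10.0,         # overrides the 20-second default
        follow_redirects=False)       # return the 3xx instead of following it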
class HTTPResponse(object):
"""HTTP Response object.
Attributes:
* request: HTTPRequest object
* code: numeric HTTP status code, e.g. 200 or 404
* reason: human-readable reason phrase describing the status code
(with curl_httpclient, this is a default value rather than the
server's actual response)
* headers: `tornado.httputil.HTTPHeaders` object
* buffer: ``cStringIO`` object for response body
* body: response body as string (created on demand from ``self.buffer``)
* error: Exception object, if any
* request_time: seconds from request start to finish
* time_info: dictionary of diagnostic timing information from the request.
Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
"""
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, request_time=None,
time_info=None, reason=None):
if isinstance(request, _RequestProxy):
self.request = request.request
else:
self.request = request
self.code = code
self.reason = reason or httputil.responses.get(code, "Unknown")
if headers is not None:
self.headers = headers
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
self._body = None
if effective_url is None:
self.effective_url = request.url
else:
self.effective_url = effective_url
if error is None:
if self.code < 200 or self.code >= 300:
self.error = HTTPError(self.code, response=self)
else:
self.error = None
else:
self.error = error
self.request_time = request_time
self.time_info = time_info or {}
def _get_body(self):
if self.buffer is None:
return None
elif self._body is None:
self._body = self.buffer.getvalue()
return self._body
body = property(_get_body)
def rethrow(self):
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
raise self.error
def __repr__(self):
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
return "%s(%s)" % (self.__class__.__name__, args)
class HTTPError(Exception):
"""Exception thrown for an unsuccessful HTTP request.
Attributes:
* ``code`` - integer HTTP error code, e.g. 404. Error code 599 is
used when no HTTP response was received, e.g. for a timeout.
* ``response`` - `HTTPResponse` object, if any.
Note that if ``follow_redirects`` is False, redirects become HTTPErrors,
and you can look at ``error.response.headers['Location']`` to see the
destination of the redirect.
"""
def __init__(self, code, message=None, response=None):
self.code = code
message = message or httputil.responses.get(code, "Unknown")
self.response = response
Exception.__init__(self, "HTTP %d: %s" % (self.code, message))
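# Illustrative sketch, not part of the original module: reading a redirect
# target when ``follow_redirects`` is False, per the note above.
def _redirect_target(url):
    client = HTTPClient()
    try:
        client.fetch(url, follow_redirects=False)
        return None  # 2xx response; there was no redirect to report
    except HTTPError as e:
        if e.response is not None and 300 <= e.code < 400:
            return e.response.headers["Location"]
        raise
    finally:
        client.close()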
class _RequestProxy(object):
"""Combines an object with a dictionary of defaults.
Used internally by AsyncHTTPClient implementations.
"""
def __init__(self, request, defaults):
self.request = request
self.defaults = defaults
def __getattr__(self, name):
request_attr = getattr(self.request, name)
if request_attr is not None:
return request_attr
elif self.defaults is not None:
return self.defaults.get(name, None)
else:
return None
def main():
from tornado.options import define, options, parse_command_line
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
define("validate_cert", type=bool, default=True)
args = parse_command_line()
client = HTTPClient()
for arg in args:
try:
response = client.fetch(arg,
follow_redirects=options.follow_redirects,
validate_cert=options.validate_cert,
)
except HTTPError as e:
if e.response is not None:
response = e.response
else:
raise
if options.print_headers:
print(response.headers)
if options.print_body:
print(response.body)
client.close()
if __name__ == "__main__":
main()


@ -0,0 +1,529 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded HTTP server.
Typical applications have little direct interaction with the `HTTPServer`
class except to start a server at the beginning of the process
(and even that is often done indirectly via `tornado.web.Application.listen`).
This module also defines the `HTTPRequest` class which is exposed via
`tornado.web.RequestHandler.request`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import socket
import ssl
import time
from tornado.escape import native_str, parse_qs_bytes
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado import stack_context
from tornado.util import bytes_type
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
class HTTPServer(TCPServer):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a request callback that takes an HTTPRequest
instance as an argument and writes a valid HTTP response with
`HTTPRequest.write`. `HTTPRequest.finish` finishes the request (but does
not necessarily close the connection in the case of HTTP/1.1 keep-alive
requests). A simple example server that echoes back the URI you
requested::
import tornado.httpserver
import tornado.ioloop
def handle_request(request):
message = "You requested %s\n" % request.uri
request.write("HTTP/1.1 200 OK\r\nContent-Length: %d\r\n\r\n%s" % (
len(message), message))
request.finish()
http_server = tornado.httpserver.HTTPServer(handle_request)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
`HTTPServer` is a very basic connection handler. It parses the request
headers and body, but the request callback is responsible for producing
the response exactly as it will appear on the wire. This affords
maximum flexibility for applications to implement whatever parts
of HTTP responses are required.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
requests ``Connection: keep-alive``). This means that the request
callback must generate a properly-framed response, using either
the ``Content-Length`` header or ``Transfer-Encoding: chunked``.
Applications that are unable to frame their responses properly
should instead return a ``Connection: close`` header in each
response and pass ``no_keep_alive=True`` to the `HTTPServer`
constructor.
If ``xheaders`` is ``True``, we support the
``X-Real-Ip``/``X-Forwarded-For`` and
``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
remote IP and URI scheme/protocol for all requests. These headers
are useful when running Tornado behind a reverse proxy or load
balancer. The ``protocol`` argument can also be set to ``https``
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
To make this server serve SSL traffic, send the ``ssl_options`` dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
an `ssl.SSLContext` object instead of a dict)::
HTTPServer(application, ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
1. `~tornado.tcpserver.TCPServer.listen`: simple single-process::
server = HTTPServer(app)
server.listen(8888)
IOLoop.instance().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
2. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`:
simple multi-process::
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
the server on the default singleton `.IOLoop`.
3. `~tornado.tcpserver.TCPServer.add_sockets`: advanced multi-process::
sockets = tornado.netutil.bind_sockets(8888)
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.instance().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
give you more flexibility in when the fork happens.
`~.TCPServer.add_sockets` can also be used in single-process
servers if you want to create your listening sockets in some
way other than `tornado.netutil.bind_sockets`.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None, **kwargs):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
self.protocol = protocol
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
**kwargs)
def handle_stream(self, stream, address):
HTTPConnection(stream, address, self.request_callback,
self.no_keep_alive, self.xheaders, self.protocol)
class _BadRequestException(Exception):
"""Exception class for malformed HTTP requests."""
pass
class HTTPConnection(object):
"""Handles a connection to an HTTP client, executing HTTP requests.
We parse HTTP headers and bodies, and execute the request callback
until the HTTP connection is closed.
"""
def __init__(self, stream, address, request_callback, no_keep_alive=False,
xheaders=False, protocol=None):
self.stream = stream
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
self.address_family = stream.socket.family
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
self.protocol = protocol
self._clear_request_state()
# Save stack context here, outside of any request. This keeps
# contexts from one request from leaking into the next.
self._header_callback = stack_context.wrap(self._on_headers)
self.stream.set_close_callback(self._on_connection_close)
self.stream.read_until(b"\r\n\r\n", self._header_callback)
def _clear_request_state(self):
"""Clears the per-request state.
This is run in between requests to allow the previous handler
to be garbage collected (and prevent spurious close callbacks),
and when the connection is closed (to break up cycles and
facilitate garbage collection in cpython).
"""
self._request = None
self._request_finished = False
self._write_callback = None
self._close_callback = None
def set_close_callback(self, callback):
"""Sets a callback that will be run when the connection is closed.
Use this instead of accessing
`HTTPConnection.stream.set_close_callback
<.BaseIOStream.set_close_callback>` directly (which was the
recommended approach prior to Tornado 3.0).
"""
self._close_callback = stack_context.wrap(callback)
def _on_connection_close(self):
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
# Delete any unfinished callbacks to break up reference cycles.
self._header_callback = None
self._clear_request_state()
def close(self):
self.stream.close()
# Remove this reference to self, which would otherwise cause a
# cycle and delay garbage collection of this connection.
self._header_callback = None
self._clear_request_state()
def write(self, chunk, callback=None):
"""Writes a chunk of output to the stream."""
if not self.stream.closed():
self._write_callback = stack_context.wrap(callback)
self.stream.write(chunk, self._on_write_complete)
def finish(self):
"""Finishes the request."""
self._request_finished = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if not self.stream.writing():
self._finish_request()
def _on_write_complete(self):
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
callback()
# _on_write_complete is enqueued on the IOLoop whenever the
# IOStream's write buffer becomes empty, but it's possible for
# another callback that runs on the IOLoop before it to
# simultaneously write more data and finish the request. If
# there is still data in the IOStream, a future
# _on_write_complete will be responsible for calling
# _finish_request.
if self._request_finished and not self.stream.writing():
self._finish_request()
def _finish_request(self):
if self.no_keep_alive or self._request is None:
disconnect = True
else:
connection_header = self._request.headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if self._request.supports_http_1_1():
disconnect = connection_header == "close"
elif ("Content-Length" in self._request.headers
or self._request.method in ("HEAD", "GET")):
disconnect = connection_header != "keep-alive"
else:
disconnect = True
self._clear_request_state()
if disconnect:
self.close()
return
try:
# Use a try/except instead of checking stream.closed()
# directly, because in some cases the stream doesn't discover
# that it's closed until you try to read from it.
self.stream.read_until(b"\r\n\r\n", self._header_callback)
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
except iostream.StreamClosedError:
self.close()
def _on_headers(self, data):
try:
data = native_str(data.decode('latin1'))
eol = data.find("\r\n")
start_line = data[:eol]
try:
method, uri, version = start_line.split(" ")
except ValueError:
raise _BadRequestException("Malformed HTTP request line")
if not version.startswith("HTTP/"):
raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
# Probably from split() if there was no ':' in the line
raise _BadRequestException("Malformed HTTP headers")
# HTTPRequest wants an IP, not a full socket address
if self.address_family in (socket.AF_INET, socket.AF_INET6):
remote_ip = self.address[0]
else:
# Unix (or other) socket; fake the remote address
remote_ip = '0.0.0.0'
self._request = HTTPRequest(
connection=self, method=method, uri=uri, version=version,
headers=headers, remote_ip=remote_ip, protocol=self.protocol)
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self.stream.max_buffer_size:
raise _BadRequestException("Content-Length too long")
if headers.get("Expect") == "100-continue":
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
self.stream.read_bytes(content_length, self._on_request_body)
return
self.request_callback(self._request)
except _BadRequestException as e:
gen_log.info("Malformed HTTP request from %s: %s",
self.address[0], e)
self.close()
return
def _on_request_body(self, data):
self._request.body = data
if self._request.method in ("POST", "PATCH", "PUT"):
httputil.parse_body_arguments(
self._request.headers.get("Content-Type", ""), data,
self._request.arguments, self._request.files)
self.request_callback(self._request)
class HTTPRequest(object):
"""A single HTTP request.
All attributes are type `str` unless otherwise noted.
.. attribute:: method
HTTP request method, e.g. "GET" or "POST"
.. attribute:: uri
The requested uri.
.. attribute:: path
The path portion of `uri`
.. attribute:: query
The query portion of `uri`
.. attribute:: version
HTTP version specified in request, e.g. "HTTP/1.1"
.. attribute:: headers
`.HTTPHeaders` dictionary-like object for request headers. Acts like
a case-insensitive dictionary with additional methods for repeated
headers.
.. attribute:: body
Request body, if present, as a byte string.
.. attribute:: remote_ip
Client's IP address as a string. If ``HTTPServer.xheaders`` is set,
will pass along the real IP address provided by a load balancer
in the ``X-Real-Ip`` or ``X-Forwarded-For`` header.
.. versionchanged:: 3.1
The list format of ``X-Forwarded-For`` is now supported.
.. attribute:: protocol
The protocol used, either "http" or "https". If ``HTTPServer.xheaders``
is set, will pass along the protocol used by a load balancer if
reported via an ``X-Scheme`` header.
.. attribute:: host
The requested hostname, usually taken from the ``Host`` header.
.. attribute:: arguments
GET/POST arguments are available in the arguments property, which
maps argument names to lists of values (to support multiple values
for individual names). Names are of type `str`, while arguments
are byte strings. Note that this is different from
`.RequestHandler.get_argument`, which returns argument values as
unicode strings.
.. attribute:: files
File uploads are available in the files property, which maps file
names to lists of `.HTTPFile`.
.. attribute:: connection
An HTTP request is attached to a single HTTP connection, which can
be accessed through the "connection" attribute. Since connections
are typically kept open in HTTP/1.1, multiple requests can be handled
sequentially on a single connection.
"""
def __init__(self, method, uri, version="HTTP/1.0", headers=None,
body=None, remote_ip=None, protocol=None, host=None,
files=None, connection=None):
self.method = method
self.uri = uri
self.version = version
self.headers = headers or httputil.HTTPHeaders()
self.body = body or ""
# set remote IP and protocol
self.remote_ip = remote_ip
if protocol:
self.protocol = protocol
elif connection and isinstance(connection.stream,
iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
# xheaders can override the defaults
if connection and connection.xheaders:
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = self.headers.get("X-Forwarded-For", self.remote_ip)
ip = ip.split(',')[-1].strip()
ip = self.headers.get(
"X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto = self.headers.get(
"X-Scheme", self.headers.get("X-Forwarded-Proto", self.protocol))
if proto in ("http", "https"):
self.protocol = proto
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.files = files or {}
self.connection = connection
self._start_time = time.time()
self._finish_time = None
self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics"""
return self.version == "HTTP/1.1"
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
except Exception:
self._cookies = {}
return self._cookies
def write(self, chunk, callback=None):
"""Writes the given chunk to the response stream."""
assert isinstance(chunk, bytes_type)
self.connection.write(chunk, callback=callback)
def finish(self):
"""Finishes this HTTP request on the open connection."""
self.connection.finish()
self._finish_time = time.time()
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
def request_time(self):
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
else:
return self._finish_time - self._start_time
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
DER-encoded form of the certificate is returned instead. See
SSLSocket.getpeercert() in the standard library for more
details.
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
return self.connection.stream.socket.getpeercert(
binary_form=binary_form)
except ssl.SSLError:
return None
def __repr__(self):
attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
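# Illustrative standalone sketch, not part of the original file: a request
# callback reading the parsed attributes described above.  Argument values
# are byte strings at this layer.
def _echo_arguments(request):
    names = ", ".join(sorted(request.arguments))
    body = ("%s %s (args: %s)\n" %
            (request.method, request.path, names)).encode("utf-8")
    request.write(("HTTP/1.1 200 OK\r\nContent-Length: %d\r\n\r\n"
                   % len(body)).encode("ascii") + body)
    request.finish()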


@ -0,0 +1,445 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""HTTP utility code shared by clients and servers."""
from __future__ import absolute_import, division, print_function, with_statement
import calendar
import collections
import datetime
import email.utils
import numbers
import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict
try:
from httplib import responses # py2
except ImportError:
from http.client import responses # py3
# responses is unused in this file, but we re-export it to other files.
# Reference it so pyflakes doesn't complain.
responses
try:
from urllib import urlencode # py2
except ImportError:
from urllib.parse import urlencode # py3
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
Implemented as a dict subclass so that cache hits are as fast as a
normal dict lookup, without the overhead of a python function
call.
>>> normalized_headers = _NormalizedHeaderCache(10)
>>> normalized_headers["coNtent-TYPE"]
'Content-Type'
"""
def __init__(self, size):
super(_NormalizedHeaderCache, self).__init__()
self.size = size
self.queue = collections.deque()
def __missing__(self, key):
normalized = "-".join([w.capitalize() for w in key.split("-")])
self[key] = normalized
self.queue.append(key)
if len(self.queue) > self.size:
# Limit the size of the cache. LRU would be better, but this
# simpler approach should be fine. In Python 2.7+ we could
# use OrderedDict (or in 3.2+, @functools.lru_cache).
old_key = self.queue.popleft()
del self[old_key]
return normalized
_normalized_headers = _NormalizedHeaderCache(1000)
class HTTPHeaders(dict):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.
Supports multiple values per key via a pair of new methods,
`add()` and `get_list()`. The regular dictionary interface
returns a single value per key, with multiple values joined by a
comma.
>>> h = HTTPHeaders({"content-type": "text/html"})
>>> list(h.keys())
['Content-Type']
>>> h["Content-Type"]
'text/html'
>>> h.add("Set-Cookie", "A=B")
>>> h.add("Set-Cookie", "C=D")
>>> h["set-cookie"]
'A=B,C=D'
>>> h.get_list("set-cookie")
['A=B', 'C=D']
>>> for (k,v) in sorted(h.get_all()):
... print('%s: %s' % (k,v))
...
Content-Type: text/html
Set-Cookie: A=B
Set-Cookie: C=D
"""
def __init__(self, *args, **kwargs):
# Don't pass args or kwargs to dict.__init__, as it will bypass
# our __setitem__
dict.__init__(self)
self._as_list = {}
self._last_key = None
if (len(args) == 1 and len(kwargs) == 0 and
isinstance(args[0], HTTPHeaders)):
# Copy constructor
for k, v in args[0].get_all():
self.add(k, v)
else:
# Dict-style initialization
self.update(*args, **kwargs)
# new public methods
def add(self, name, value):
"""Adds a new value for the given key."""
norm_name = _normalized_headers[name]
self._last_key = norm_name
if norm_name in self:
# bypass our override of __setitem__ since it modifies _as_list
dict.__setitem__(self, norm_name,
native_str(self[norm_name]) + ',' +
native_str(value))
self._as_list[norm_name].append(value)
else:
self[norm_name] = value
def get_list(self, name):
"""Returns all values for the given header as a list."""
norm_name = _normalized_headers[name]
return self._as_list.get(norm_name, [])
def get_all(self):
"""Returns an iterable of all (name, value) pairs.
If a header has multiple values, multiple pairs will be
returned with the same name.
"""
for name, values in self._as_list.items():
for value in values:
yield (name, value)
def parse_line(self, line):
"""Updates the dictionary with a single header line.
>>> h = HTTPHeaders()
>>> h.parse_line("Content-Type: text/html")
>>> h.get('content-type')
'text/html'
"""
if line[0].isspace():
# continuation of a multi-line header
new_part = ' ' + line.lstrip()
self._as_list[self._last_key][-1] += new_part
dict.__setitem__(self, self._last_key,
self[self._last_key] + new_part)
else:
name, value = line.split(":", 1)
self.add(name, value.strip())
@classmethod
def parse(cls, headers):
"""Returns a dictionary from HTTP header text.
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
>>> sorted(h.items())
[('Content-Length', '42'), ('Content-Type', 'text/html')]
"""
h = cls()
for line in headers.splitlines():
if line:
h.parse_line(line)
return h
# dict implementation overrides
def __setitem__(self, name, value):
norm_name = _normalized_headers[name]
dict.__setitem__(self, norm_name, value)
self._as_list[norm_name] = [value]
def __getitem__(self, name):
return dict.__getitem__(self, _normalized_headers[name])
def __delitem__(self, name):
norm_name = _normalized_headers[name]
dict.__delitem__(self, norm_name)
del self._as_list[norm_name]
def __contains__(self, name):
norm_name = _normalized_headers[name]
return dict.__contains__(self, norm_name)
def get(self, name, default=None):
return dict.get(self, _normalized_headers[name], default)
def update(self, *args, **kwargs):
# dict.update bypasses our __setitem__
for k, v in dict(*args, **kwargs).items():
self[k] = v
def copy(self):
# default implementation returns dict(self), not the subclass
return HTTPHeaders(self)
def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether
url has existing query parameters.
>>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d'
"""
if not args:
return url
if url[-1] not in ('?', '&'):
url += '&' if ('?' in url) else '?'
return url + urlencode(args)
class HTTPFile(ObjectDict):
"""Represents a file uploaded via a form.
For backwards compatibility, its instance attributes are also
accessible as dictionary keys.
* ``filename``
* ``body``
* ``content_type``
"""
pass
def _parse_request_range(range_header):
"""Parses a Range header.
Returns either ``None`` or tuple ``(start, end)``.
Note that while the HTTP headers use inclusive byte positions,
this method returns indexes suitable for use in slices.
>>> start, end = _parse_request_range("bytes=1-2")
>>> start, end
(1, 3)
>>> [0, 1, 2, 3, 4][start:end]
[1, 2]
>>> _parse_request_range("bytes=6-")
(6, None)
>>> _parse_request_range("bytes=-6")
(-6, None)
>>> _parse_request_range("bytes=-0")
(None, 0)
>>> _parse_request_range("bytes=")
(None, None)
>>> _parse_request_range("foo=42")
>>> _parse_request_range("bytes=1-2,6-10")
Note: only supports one range (e.g., ``bytes=1-2,6-10`` is not allowed).
See [0] for the details of the range header.
[0]: http://greenbytes.de/tech/webdav/draft-ietf-httpbis-p5-range-latest.html#byte.ranges
"""
unit, _, value = range_header.partition("=")
unit, value = unit.strip(), value.strip()
if unit != "bytes":
return None
start_b, _, end_b = value.partition("-")
try:
start = _int_or_none(start_b)
end = _int_or_none(end_b)
except ValueError:
return None
if end is not None:
if start is None:
if end != 0:
start = -end
end = None
else:
end += 1
return (start, end)
def _get_content_range(start, end, total):
"""Returns a suitable Content-Range header:
>>> print(_get_content_range(None, 1, 4))
bytes 0-0/4
>>> print(_get_content_range(1, 3, 4))
bytes 1-2/4
>>> print(_get_content_range(None, None, 4))
bytes 0-3/4
"""
start = start or 0
end = (end or total) - 1
return "bytes %s-%s/%s" % (start, end, total)
def _int_or_none(val):
val = val.strip()
if val == "":
return None
return int(val)
def parse_body_arguments(content_type, body, arguments, files):
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
``multipart/form-data``. The ``content_type`` parameter should be
a string and ``body`` should be a byte string. The ``arguments``
and ``files`` parameters are dictionaries that will be updated
with the parsed contents.
"""
if content_type.startswith("application/x-www-form-urlencoded"):
uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
for name, values in uri_arguments.items():
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
gen_log.warning("Invalid multipart/form-data")
def parse_multipart_form_data(boundary, data, arguments, files):
"""Parses a ``multipart/form-data`` body.
The ``boundary`` and ``data`` parameters are both byte strings.
The dictionaries given in the arguments and files parameters
will be updated with the contents of the body.
"""
# The standard allows for the boundary to be quoted in the header,
# although it's rare (it happens at least for google app engine
# xmpp). I think we're also supposed to handle backslash-escapes
# here but I'll save that until we see a client that uses them
# in the wild.
if boundary.startswith(b'"') and boundary.endswith(b'"'):
boundary = boundary[1:-1]
final_boundary_index = data.rfind(b"--" + boundary + b"--")
if final_boundary_index == -1:
gen_log.warning("Invalid multipart/form-data: no final boundary")
return
parts = data[:final_boundary_index].split(b"--" + boundary + b"\r\n")
for part in parts:
if not part:
continue
eoh = part.find(b"\r\n\r\n")
if eoh == -1:
gen_log.warning("multipart/form-data missing headers")
continue
headers = HTTPHeaders.parse(part[:eoh].decode("utf-8"))
disp_header = headers.get("Content-Disposition", "")
disposition, disp_params = _parse_header(disp_header)
if disposition != "form-data" or not part.endswith(b"\r\n"):
gen_log.warning("Invalid multipart/form-data")
continue
value = part[eoh + 4:-2]
if not disp_params.get("name"):
gen_log.warning("multipart/form-data value missing name")
continue
name = disp_params["name"]
if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
files.setdefault(name, []).append(HTTPFile(
filename=disp_params["filename"], body=value,
content_type=ctype))
else:
arguments.setdefault(name, []).append(value)
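# Illustrative sketch, not part of the original module: feeding a raw
# multipart body through the parser above.
def _multipart_example():
    boundary = b"1234"
    body = (b"--1234\r\n"
            b'Content-Disposition: form-data; name="title"\r\n'
            b"\r\n"
            b"hello\r\n"
            b"--1234--\r\n")
    arguments, files = {}, {}
    parse_multipart_form_data(boundary, body, arguments, files)
    return arguments  # {'title': [b'hello']}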
def format_timestamp(ts):
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
a time tuple as returned by `time.gmtime`, or a `datetime.datetime`
object.
>>> format_timestamp(1359312200)
'Sun, 27 Jan 2013 18:43:20 GMT'
"""
if isinstance(ts, numbers.Real):
pass
elif isinstance(ts, (tuple, time.struct_time)):
ts = calendar.timegm(ts)
elif isinstance(ts, datetime.datetime):
ts = calendar.timegm(ts.utctimetuple())
else:
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(ts, usegmt=True)
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
def _parseparam(s):
while s[:1] == ';':
s = s[1:]
end = s.find(';')
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
end = s.find(';', end + 1)
if end < 0:
end = len(s)
f = s[:end]
yield f.strip()
s = s[end:]
def _parse_header(line):
"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
"""
parts = _parseparam(';' + line)
key = next(parts)
pdict = {}
for p in parts:
i = p.find('=')
if i >= 0:
name = p[:i].strip().lower()
value = p[i + 1:].strip()
if len(value) >= 2 and value[0] == value[-1] == '"':
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value
return key, pdict
def doctests():
import doctest
return doctest.DocTestSuite()


@ -0,0 +1,824 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""An I/O event loop for non-blocking sockets.
Typical applications will use a single `IOLoop` object, in the
`IOLoop.instance` singleton. The `IOLoop.start` method should usually
be called at the end of the ``main()`` function. Atypical applications may
use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest`
case.
In addition to I/O events, the `IOLoop` can also schedule time-based events.
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import errno
import functools
import heapq
import logging
import numbers
import os
import select
import sys
import threading
import time
import traceback
from tornado.concurrent import Future, TracebackFuture
from tornado.log import app_log, gen_log
from tornado import stack_context
from tornado.util import Configurable
try:
import signal
except ImportError:
signal = None
try:
import thread # py2
except ImportError:
import _thread as thread # py3
from tornado.platform.auto import set_close_exec, Waker
class TimeoutError(Exception):
pass
class IOLoop(Configurable):
"""A level-triggered I/O loop.
We use ``epoll`` (Linux) or ``kqueue`` (BSD and Mac OS X) if they
are available, or else we fall back on select(). If you are
implementing a system that needs to handle thousands of
simultaneous connections, you should use a system that supports
either ``epoll`` or ``kqueue``.
Example usage for a simple TCP server::
import errno
import functools
import ioloop
import socket
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error as e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
handle_connection(connection, address)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", port))
sock.listen(128)
io_loop = ioloop.IOLoop.instance()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
"""
# Constants from the epoll module
_EPOLLIN = 0x001
_EPOLLPRI = 0x002
_EPOLLOUT = 0x004
_EPOLLERR = 0x008
_EPOLLHUP = 0x010
_EPOLLRDHUP = 0x2000
_EPOLLONESHOT = (1 << 30)
_EPOLLET = (1 << 31)
# Our events map exactly to the epoll events
NONE = 0
READ = _EPOLLIN
WRITE = _EPOLLOUT
ERROR = _EPOLLERR | _EPOLLHUP
# Global lock for creating global IOLoop instance
_instance_lock = threading.Lock()
_current = threading.local()
@staticmethod
def instance():
"""Returns a global `IOLoop` instance.
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
"""
if not hasattr(IOLoop, "_instance"):
with IOLoop._instance_lock:
if not hasattr(IOLoop, "_instance"):
# New instance after double check
IOLoop._instance = IOLoop()
return IOLoop._instance
@staticmethod
def initialized():
"""Returns true if the singleton instance has been created."""
return hasattr(IOLoop, "_instance")
def install(self):
"""Installs this `IOLoop` object as the singleton instance.
This is normally not necessary as `instance()` will create
an `IOLoop` on demand, but you may want to call `install` to use
a custom subclass of `IOLoop`.
"""
assert not IOLoop.initialized()
IOLoop._instance = self
@staticmethod
def current():
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as current
by `make_current`, returns that instance. Otherwise returns
`IOLoop.instance()`, i.e. the main thread's `IOLoop`.
A common pattern for classes that depend on ``IOLoops`` is to use
a default argument to enable programs with multiple ``IOLoops``
but not require the argument for simpler applications::
class MyClass(object):
def __init__(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
In general you should use `IOLoop.current` as the default when
constructing an asynchronous object, and use `IOLoop.instance`
when you mean to communicate to the main thread from a different
one.
"""
current = getattr(IOLoop._current, "instance", None)
if current is None:
return IOLoop.instance()
return current
def make_current(self):
"""Makes this the `IOLoop` for the current thread.
An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
"""
IOLoop._current.instance = self
@staticmethod
def clear_current():
IOLoop._current.instance = None
@classmethod
def configurable_base(cls):
return IOLoop
@classmethod
def configurable_default(cls):
if hasattr(select, "epoll"):
from tornado.platform.epoll import EPollIOLoop
return EPollIOLoop
if hasattr(select, "kqueue"):
# Python 2.6+ on BSD or Mac
from tornado.platform.kqueue import KQueueIOLoop
return KQueueIOLoop
from tornado.platform.select import SelectIOLoop
return SelectIOLoop
def initialize(self):
pass
def close(self, all_fds=False):
"""Closes the `IOLoop`, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
IOLoop will be closed (not just the ones created by the
`IOLoop` itself).
Many applications will only use a single `IOLoop` that runs for the
entire lifetime of the process. In that case closing the `IOLoop`
is not necessary since everything will be cleaned up when the
process exits. `IOLoop.close` is provided mainly for scenarios
such as unit tests, which create and destroy a large number of
``IOLoops``.
An `IOLoop` must be completely stopped before it can be closed. This
means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must
be allowed to return before attempting to call `IOLoop.close()`.
Therefore the call to `close` will usually appear just after
the call to `start` rather than near the call to `stop`.
.. versionchanged:: 3.1
If the `IOLoop` implementation supports non-integer objects
for "file descriptors", those objects will have their
           ``close`` method called when ``all_fds`` is true.
"""
raise NotImplementedError()
def add_handler(self, fd, handler, events):
"""Registers the given handler to receive the given events for fd.
The ``events`` argument is a bitwise or of the constants
``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``.
When an event occurs, ``handler(fd, events)`` will be run.
"""
raise NotImplementedError()
def update_handler(self, fd, events):
"""Changes the events we listen for fd."""
raise NotImplementedError()
def remove_handler(self, fd):
"""Stop listening for events on fd."""
raise NotImplementedError()
def set_blocking_signal_threshold(self, seconds, action):
"""Sends a signal if the `IOLoop` is blocked for more than
        ``seconds`` seconds.
Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy
platform.
The action parameter is a Python signal handler. Read the
documentation for the `signal` module for more information.
If ``action`` is None, the process will be killed if it is
blocked for too long.
"""
raise NotImplementedError()
def set_blocking_log_threshold(self, seconds):
"""Logs a stack trace if the `IOLoop` is blocked for more than
        ``seconds`` seconds.
Equivalent to ``set_blocking_signal_threshold(seconds,
self.log_stack)``
"""
self.set_blocking_signal_threshold(seconds, self.log_stack)
def log_stack(self, signal, frame):
"""Signal handler to log the stack trace of the current thread.
For use with `set_blocking_signal_threshold`.
"""
gen_log.warning('IOLoop blocked for %f seconds in\n%s',
self._blocking_signal_threshold,
''.join(traceback.format_stack(frame)))
def start(self):
"""Starts the I/O loop.
The loop will run until one of the callbacks calls `stop()`, which
will make the loop stop after the current event iteration completes.
"""
raise NotImplementedError()
def stop(self):
"""Stop the I/O loop.
If the event loop is not currently running, the next call to `start()`
will return immediately.
To use asynchronous methods from otherwise-synchronous code (such as
unit tests), you can start and stop the event loop like this::
ioloop = IOLoop()
async_method(ioloop=ioloop, callback=ioloop.stop)
ioloop.start()
``ioloop.start()`` will return after ``async_method`` has run
its callback, whether that callback was invoked before or
after ``ioloop.start``.
Note that even after `stop` has been called, the `IOLoop` is not
completely stopped until `IOLoop.start` has also returned.
Some work that was scheduled before the call to `stop` may still
be run before the `IOLoop` shuts down.
"""
raise NotImplementedError()
def run_sync(self, func, timeout=None):
"""Starts the `IOLoop`, runs the given function, and stops the loop.
If the function returns a `.Future`, the `IOLoop` will run
until the future is resolved. If it raises an exception, the
`IOLoop` will stop and the exception will be re-raised to the
caller.
The keyword-only argument ``timeout`` may be used to set
a maximum duration for the function. If the timeout expires,
a `TimeoutError` is raised.
This method is useful in conjunction with `tornado.gen.coroutine`
to allow asynchronous calls in a ``main()`` function::
@gen.coroutine
def main():
# do stuff...
if __name__ == '__main__':
IOLoop.instance().run_sync(main)
"""
future_cell = [None]
def run():
try:
result = func()
except Exception:
future_cell[0] = TracebackFuture()
future_cell[0].set_exc_info(sys.exc_info())
else:
if isinstance(result, Future):
future_cell[0] = result
else:
future_cell[0] = Future()
future_cell[0].set_result(result)
self.add_future(future_cell[0], lambda future: self.stop())
self.add_callback(run)
if timeout is not None:
timeout_handle = self.add_timeout(self.time() + timeout, self.stop)
self.start()
if timeout is not None:
self.remove_timeout(timeout_handle)
if not future_cell[0].done():
raise TimeoutError('Operation timed out after %s seconds' % timeout)
return future_cell[0].result()
def time(self):
"""Returns the current time according to the `IOLoop`'s clock.
The return value is a floating-point number relative to an
unspecified time in the past.
By default, the `IOLoop`'s time function is `time.time`. However,
it may be configured to use e.g. `time.monotonic` instead.
Calls to `add_timeout` that pass a number instead of a
`datetime.timedelta` should use this function to compute the
appropriate time, so they can work no matter what time function
is chosen.
"""
return time.time()
def add_timeout(self, deadline, callback):
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to
`remove_timeout` to cancel.
``deadline`` may be a number denoting a time (on the same
scale as `IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
Note that it is not safe to call `add_timeout` from other threads.
Instead, you must use `add_callback` to transfer control to the
`IOLoop`'s thread, and then call `add_timeout` from there.
"""
raise NotImplementedError()
def remove_timeout(self, timeout):
"""Cancels a pending timeout.
The argument is a handle as returned by `add_timeout`. It is
safe to call `remove_timeout` even if the callback has already
been run.
"""
raise NotImplementedError()
def add_callback(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
except from a signal handler. Note that this is the **only**
method in `IOLoop` that makes this thread-safety guarantee; all
other interaction with the `IOLoop` must be done from that
`IOLoop`'s thread. `add_callback()` may be used to transfer
control from other threads to the `IOLoop`'s thread.
To add a callback from a signal handler, see
`add_callback_from_signal`.
"""
raise NotImplementedError()
def add_callback_from_signal(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
otherwise.
Callbacks added with this method will be run without any
`.stack_context`, to avoid picking up the context of the function
that was interrupted by the signal.
"""
raise NotImplementedError()
def add_future(self, future, callback):
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
The callback is invoked with one argument, the
`.Future`.
"""
assert isinstance(future, Future)
callback = stack_context.wrap(callback)
future.add_done_callback(
lambda future: self.add_callback(callback, future))
def _run_callback(self, callback):
"""Runs a callback with error handling.
For use in subclasses.
"""
try:
callback()
except Exception:
self.handle_callback_exception(callback)
def handle_callback_exception(self, callback):
"""This method is called whenever a callback run by the `IOLoop`
throws an exception.
By default simply logs the exception as an error. Subclasses
may override this method to customize reporting of exceptions.
The exception itself is not passed explicitly, but is available
in `sys.exc_info`.
"""
app_log.error("Exception in callback %r", callback, exc_info=True)
class PollIOLoop(IOLoop):
"""Base class for IOLoops built around a select-like function.
For concrete implementations, see `tornado.platform.epoll.EPollIOLoop`
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
`tornado.platform.select.SelectIOLoop` (all platforms).
"""
def initialize(self, impl, time_func=None):
super(PollIOLoop, self).initialize()
self._impl = impl
if hasattr(self._impl, 'fileno'):
set_close_exec(self._impl.fileno())
self.time_func = time_func or time.time
self._handlers = {}
self._events = {}
self._callbacks = []
self._callback_lock = threading.Lock()
self._timeouts = []
self._cancellations = 0
self._running = False
self._stopped = False
self._closing = False
self._thread_ident = None
self._blocking_signal_threshold = None
# Create a pipe that we send bogus data to when we want to wake
# the I/O loop when it is idle
self._waker = Waker()
self.add_handler(self._waker.fileno(),
lambda fd, events: self._waker.consume(),
self.READ)
def close(self, all_fds=False):
with self._callback_lock:
self._closing = True
self.remove_handler(self._waker.fileno())
if all_fds:
for fd in self._handlers.keys():
try:
close_method = getattr(fd, 'close', None)
if close_method is not None:
close_method()
else:
os.close(fd)
except Exception:
gen_log.debug("error closing fd %s", fd, exc_info=True)
self._waker.close()
self._impl.close()
def add_handler(self, fd, handler, events):
self._handlers[fd] = stack_context.wrap(handler)
self._impl.register(fd, events | self.ERROR)
def update_handler(self, fd, events):
self._impl.modify(fd, events | self.ERROR)
def remove_handler(self, fd):
self._handlers.pop(fd, None)
self._events.pop(fd, None)
try:
self._impl.unregister(fd)
except Exception:
gen_log.debug("Error deleting fd from IOLoop", exc_info=True)
def set_blocking_signal_threshold(self, seconds, action):
if not hasattr(signal, "setitimer"):
gen_log.error("set_blocking_signal_threshold requires a signal module "
"with the setitimer method")
return
self._blocking_signal_threshold = seconds
if seconds is not None:
signal.signal(signal.SIGALRM,
action if action is not None else signal.SIG_DFL)
def start(self):
if not logging.getLogger().handlers:
# The IOLoop catches and logs exceptions, so it's
# important that log output be visible. However, python's
# default behavior for non-root loggers (prior to python
# 3.2) is to print an unhelpful "no handlers could be
# found" message rather than the actual log entry, so we
# must explicitly configure logging if we've made it this
# far without anything.
logging.basicConfig()
if self._stopped:
self._stopped = False
return
old_current = getattr(IOLoop._current, "instance", None)
IOLoop._current.instance = self
self._thread_ident = thread.get_ident()
self._running = True
# signal.set_wakeup_fd closes a race condition in event loops:
# a signal may arrive at the beginning of select/poll/etc
# before it goes into its interruptible sleep, so the signal
# will be consumed without waking the select. The solution is
# for the (C, synchronous) signal handler to write to a pipe,
# which will then be seen by select.
#
# In python's signal handling semantics, this only matters on the
# main thread (fortunately, set_wakeup_fd only works on the main
# thread and will raise a ValueError otherwise).
#
# If someone has already set a wakeup fd, we don't want to
# disturb it. This is an issue for twisted, which does its
# SIGCHILD processing in response to its own wakeup fd being
# written to. As long as the wakeup fd is registered on the IOLoop,
# the loop will still wake up and everything should work.
old_wakeup_fd = None
if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix':
# requires python 2.6+, unix. set_wakeup_fd exists but crashes
# the python process on windows.
try:
old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno())
if old_wakeup_fd != -1:
# Already set, restore previous value. This is a little racy,
# but there's no clean get_wakeup_fd and in real use the
# IOLoop is just started once at the beginning.
signal.set_wakeup_fd(old_wakeup_fd)
old_wakeup_fd = None
except ValueError: # non-main thread
pass
while True:
poll_timeout = 3600.0
# Prevent IO event starvation by delaying new callbacks
# to the next iteration of the event loop.
with self._callback_lock:
callbacks = self._callbacks
self._callbacks = []
for callback in callbacks:
self._run_callback(callback)
if self._timeouts:
now = self.time()
while self._timeouts:
if self._timeouts[0].callback is None:
# the timeout was cancelled
heapq.heappop(self._timeouts)
self._cancellations -= 1
elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts)
self._run_callback(timeout.callback)
else:
seconds = self._timeouts[0].deadline - now
poll_timeout = min(seconds, poll_timeout)
break
if (self._cancellations > 512
and self._cancellations > (len(self._timeouts) >> 1)):
# Clean up the timeout queue when it gets large and it's
# more than half cancellations.
self._cancellations = 0
self._timeouts = [x for x in self._timeouts
if x.callback is not None]
heapq.heapify(self._timeouts)
if self._callbacks:
# If any callbacks or timeouts called add_callback,
# we don't want to wait in poll() before we run them.
poll_timeout = 0.0
if not self._running:
break
if self._blocking_signal_threshold is not None:
# clear alarm so it doesn't fire while poll is waiting for
# events.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
try:
event_pairs = self._impl.poll(poll_timeout)
except Exception as e:
# Depending on python version and IOLoop implementation,
# different exception types may be thrown and there are
# two ways EINTR might be signaled:
# * e.errno == errno.EINTR
# * e.args is like (errno.EINTR, 'Interrupted system call')
if (getattr(e, 'errno', None) == errno.EINTR or
(isinstance(getattr(e, 'args', None), tuple) and
len(e.args) == 2 and e.args[0] == errno.EINTR)):
continue
else:
raise
if self._blocking_signal_threshold is not None:
signal.setitimer(signal.ITIMER_REAL,
self._blocking_signal_threshold, 0)
# Pop one fd at a time from the set of pending fds and run
# its handler. Since that handler may perform actions on
# other file descriptors, there may be reentrant calls to
# this IOLoop that update self._events
self._events.update(event_pairs)
while self._events:
fd, events = self._events.popitem()
try:
self._handlers[fd](fd, events)
except (OSError, IOError) as e:
if e.args[0] == errno.EPIPE:
# Happens when the client closes the connection
pass
else:
app_log.error("Exception in I/O handler for fd %s",
fd, exc_info=True)
except Exception:
app_log.error("Exception in I/O handler for fd %s",
fd, exc_info=True)
# reset the stopped flag so another start/stop pair can be issued
self._stopped = False
if self._blocking_signal_threshold is not None:
signal.setitimer(signal.ITIMER_REAL, 0, 0)
IOLoop._current.instance = old_current
if old_wakeup_fd is not None:
signal.set_wakeup_fd(old_wakeup_fd)
def stop(self):
self._running = False
self._stopped = True
self._waker.wake()
def time(self):
return self.time_func()
def add_timeout(self, deadline, callback):
timeout = _Timeout(deadline, stack_context.wrap(callback), self)
heapq.heappush(self._timeouts, timeout)
return timeout
def remove_timeout(self, timeout):
# Removing from a heap is complicated, so just leave the defunct
# timeout object in the queue (see discussion in
# http://docs.python.org/library/heapq.html).
# If this turns out to be a problem, we could add a garbage
# collection pass whenever there are too many dead timeouts.
timeout.callback = None
self._cancellations += 1
def add_callback(self, callback, *args, **kwargs):
with self._callback_lock:
if self._closing:
raise RuntimeError("IOLoop is closing")
list_empty = not self._callbacks
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty and thread.get_ident() != self._thread_ident:
# If we're in the IOLoop's thread, we know it's not currently
# polling. If we're not, and we added the first callback to an
# empty list, we may need to wake it up (it may wake up on its
# own, but an occasional extra wake is harmless). Waking
# up a polling IOLoop is relatively expensive, so we try to
# avoid it when we can.
self._waker.wake()
def add_callback_from_signal(self, callback, *args, **kwargs):
with stack_context.NullContext():
if thread.get_ident() != self._thread_ident:
# if the signal is handled on another thread, we can add
# it normally (modulo the NullContext)
self.add_callback(callback, *args, **kwargs)
else:
# If we're on the IOLoop's thread, we cannot use
# the regular add_callback because it may deadlock on
# _callback_lock. Blindly insert into self._callbacks.
# This is safe because the GIL makes list.append atomic.
# One subtlety is that if the signal interrupted the
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks,
# but either way will work.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
__slots__ = ['deadline', 'callback']
def __init__(self, deadline, callback, io_loop):
if isinstance(deadline, numbers.Real):
self.deadline = deadline
elif isinstance(deadline, datetime.timedelta):
self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r" % deadline)
self.callback = callback
@staticmethod
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
def __lt__(self, other):
return ((self.deadline, id(self)) <
(other.deadline, id(other)))
def __le__(self, other):
return ((self.deadline, id(self)) <=
(other.deadline, id(other)))
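# A minimal illustrative sketch (hypothetical helper): per the `add_timeout`
# docstring, deadlines may be absolute numbers on the IOLoop's clock or
# relative `datetime.timedelta` objects.
def _example_timeouts(io_loop):
    # Absolute deadline: compute it from io_loop.time() rather than
    # time.time(), so the code keeps working if the loop is configured
    # with a monotonic clock.
    handle = io_loop.add_timeout(io_loop.time() + 2.5, io_loop.stop)
    # Relative deadline: the equivalent timedelta spelling.
    io_loop.add_timeout(datetime.timedelta(seconds=2.5), io_loop.stop)
    # Timeouts can be cancelled; per remove_timeout's docstring this is
    # safe even if the callback has already run.
    io_loop.remove_timeout(handle)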
class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
`start` must be called after the `PeriodicCallback` is created.
"""
def __init__(self, callback, callback_time, io_loop=None):
self.callback = callback
if callback_time <= 0:
raise ValueError("Periodic callback must have a positive callback_time")
self.callback_time = callback_time
self.io_loop = io_loop or IOLoop.current()
self._running = False
self._timeout = None
def start(self):
"""Starts the timer."""
self._running = True
self._next_timeout = self.io_loop.time()
self._schedule_next()
def stop(self):
"""Stops the timer."""
self._running = False
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _run(self):
if not self._running:
return
try:
self.callback()
except Exception:
app_log.error("Error in periodic callback", exc_info=True)
self._schedule_next()
def _schedule_next(self):
if self._running:
current_time = self.io_loop.time()
while self._next_timeout <= current_time:
self._next_timeout += self.callback_time / 1000.0
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
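# A minimal illustrative sketch (hypothetical helper): a PeriodicCallback
# fires every callback_time milliseconds and must be started explicitly.
def _example_heartbeat(io_loop):
    def beat():
        gen_log.info("heartbeat at %f", io_loop.time())
    pc = PeriodicCallback(beat, 1000, io_loop=io_loop)  # 1000 ms = 1 s
    pc.start()
    return pc  # keep a reference; call pc.stop() to cancel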

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,513 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Translation methods for generating localized strings.
To load a locale and generate a translated string::
user_locale = tornado.locale.get("es_LA")
    print(user_locale.translate("Sign out"))
`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with
additional arguments to `~Locale.translate()`, e.g.::
people = [...]
message = user_locale.translate(
"%(list)s is online", "%(list)s are online", len(people))
    print(message % {"list": user_locale.list(people)})
The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.
Applications should call one of `load_translations` (which uses a simple
CSV format) or `load_gettext_translations` (which uses the ``.mo`` format
supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""
from __future__ import absolute_import, division, print_function, with_statement
import csv
import datetime
import numbers
import os
import re
from tornado import escape
from tornado.log import gen_log
from tornado.util import u
_default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
def get(*locale_codes):
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
or a loose match for the code (e.g., "en" for "en_US"), we return
the locale. Otherwise we move to the next code in the list.
By default we return ``en_US`` if no translations are found for any of
the specified locales. You can change the default locale with
`set_default_locale()`.
"""
return Locale.get_closest(*locale_codes)
def set_default_locale(code):
"""Sets the default locale.
The default locale is assumed to be the language used for all strings
in the system. The translations loaded from disk are mappings from
the default locale to the destination locale. Consequently, you don't
need to create a translation file for the default locale.
"""
global _default_locale
global _supported_locales
_default_locale = code
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
def load_translations(directory):
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
(e.g., ``My name is %(name)s``) and their associated translations.
The directory should have translation files of the form ``LOCALE.csv``,
e.g. ``es_GT.csv``. The CSV files should have two or three columns: string,
translation, and an optional plural indicator. Plural indicators should
be one of "plural" or "singular". A given string can have both singular
and plural forms. For example ``%(name)s liked this`` may have a
different verb conjugation depending on whether %(name)s is one
name or a list of names. There should be two rows in the CSV file for
that string, one with plural indicator "singular", and one "plural".
For strings with no verbs that would change on translation, simply
use "unknown" or the empty string (or don't include the column at all).
The file is read using the `csv` module in the default "excel" dialect.
In this format there should not be spaces after the commas.
Example translation ``es_LA.csv``::
"I love you","Te amo"
"%(name)s liked this","A %(name)s les gustó esto","plural"
"%(name)s liked this","A %(name)s le gustó esto","singular"
"""
global _translations
global _supported_locales
_translations = {}
for path in os.listdir(directory):
if not path.endswith(".csv"):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
gen_log.error("Unrecognized locale %r (path: %s)", locale,
os.path.join(directory, path))
continue
full_path = os.path.join(directory, path)
try:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding="utf-8")
except TypeError:
# python 2: files return byte strings, which are decoded below.
f = open(full_path, "r")
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2:
continue
row = [escape.to_unicode(c).strip() for c in row]
english, translation = row[:2]
if len(row) > 2:
plural = row[2] or "unknown"
else:
plural = "unknown"
if plural not in ("plural", "singular", "unknown"):
gen_log.error("Unrecognized plural indicator %r in %s line %d",
plural, path, i + 1)
continue
_translations[locale].setdefault(plural, {})[english] = translation
f.close()
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def load_gettext_translations(directory, domain):
"""Loads translations from `gettext`'s locale tree
    The locale tree is similar to the system's ``/usr/share/locale``, like::
{directory}/{lang}/LC_MESSAGES/{domain}.mo
    Three steps are required to have your app translated:
1. Generate POT translation file::
xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc
2. Merge against existing POT file::
msgmerge old.po mydomain.po > new.po
3. Compile::
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
import gettext
global _translations
global _supported_locales
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
if lang.startswith('.'):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
_translations[lang] = gettext.translation(domain, directory,
languages=[lang])
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
_use_gettext = True
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def get_supported_locales():
"""Returns a list of all the supported locale codes."""
return _supported_locales
class Locale(object):
"""Object representing a locale.
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
@classmethod
def get_closest(cls, *locale_codes):
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
continue
code = code.replace("-", "_")
parts = code.split("_")
if len(parts) > 2:
continue
elif len(parts) == 2:
code = parts[0].lower() + "_" + parts[1].upper()
if code in _supported_locales:
return cls.get(code)
if parts[0].lower() in _supported_locales:
return cls.get(parts[0].lower())
return cls.get(_default_locale)
@classmethod
def get(cls, code):
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
"""
if not hasattr(cls, "_cache"):
cls._cache = {}
if code not in cls._cache:
assert code in _supported_locales
translations = _translations.get(code, None)
if translations is None:
locale = CSVLocale(code, {})
elif _use_gettext:
locale = GettextLocale(code, translations)
else:
locale = CSVLocale(code, translations)
cls._cache[code] = locale
return cls._cache[code]
def __init__(self, code, translations):
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u("Unknown"))
self.rtl = False
for prefix in ["fa", "ar", "he"]:
if self.code.startswith(prefix):
self.rtl = True
break
self.translations = translations
# Initialize strings for date formatting
_ = self.translate
self._months = [
_("January"), _("February"), _("March"), _("April"),
_("May"), _("June"), _("July"), _("August"),
_("September"), _("October"), _("November"), _("December")]
self._weekdays = [
_("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
_("Friday"), _("Saturday"), _("Sunday")]
def translate(self, message, plural_message=None, count=None):
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
``count``. We return ``plural_message`` when ``count != 1``,
and we return the singular form for the given message when
``count == 1``.
"""
raise NotImplementedError()
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
can return an absolute date string with ``relative=False``.
You can force a full format date ("July 10, 1980") with
``full_format=True``.
This method is primarily intended for dates in the past.
For dates in the future, we fall back to full format.
"""
if self.code.startswith("ru"):
relative = False
if isinstance(date, numbers.Real):
date = datetime.datetime.utcfromtimestamp(date)
now = datetime.datetime.utcnow()
if date > now:
if relative and (date - now).seconds < 60:
                # Due to clock skew, some things are slightly
                # in the future. Round timestamps in the immediate
# future down to now in relative mode.
date = now
else:
# Otherwise, future dates always use the full format.
full_format = True
local_date = date - datetime.timedelta(minutes=gmt_offset)
local_now = now - datetime.timedelta(minutes=gmt_offset)
local_yesterday = local_now - datetime.timedelta(hours=24)
difference = now - date
seconds = difference.seconds
days = difference.days
_ = self.translate
format = None
if not full_format:
if relative and days == 0:
if seconds < 50:
return _("1 second ago", "%(seconds)d seconds ago",
seconds) % {"seconds": seconds}
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
return _("1 minute ago", "%(minutes)d minutes ago",
minutes) % {"minutes": minutes}
hours = round(seconds / (60.0 * 60))
return _("1 hour ago", "%(hours)d hours ago",
hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
elif days == 1 and local_date.day == local_yesterday.day and \
relative:
format = _("yesterday") if shorter else \
_("yesterday at %(time)s")
elif days < 5:
format = _("%(weekday)s") if shorter else \
_("%(weekday)s at %(time)s")
elif days < 334: # 11mo, since confusing for same month last year
format = _("%(month_name)s %(day)s") if shorter else \
_("%(month_name)s %(day)s at %(time)s")
if format is None:
format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
_("%(month_name)s %(day)s, %(year)s at %(time)s")
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
(u('\u4e0a\u5348'), u('\u4e0b\u5348'))[local_date.hour >= 12],
local_date.hour % 12 or 12, local_date.minute)
else:
str_time = "%d:%02d %s" % (
local_date.hour % 12 or 12, local_date.minute,
("am", "pm")[local_date.hour >= 12])
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
"time": str_time
}
def format_day(self, date, gmt_offset=0, dow=True):
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
``dow=False``.
"""
local_date = date - datetime.timedelta(minutes=gmt_offset)
_ = self.translate
if dow:
return _("%(weekday)s, %(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
}
else:
return _("%(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"day": str(local_date.day),
}
def list(self, parts):
"""Returns a comma-separated list for the given list of parts.
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
of size 1.
"""
_ = self.translate
if len(parts) == 0:
return ""
if len(parts) == 1:
return parts[0]
comma = u(' \u0648 ') if self.code.startswith("fa") else u(", ")
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
}
def friendly_number(self, value):
"""Returns a comma-separated number for the given integer."""
if self.code not in ("en", "en_US"):
return str(value)
value = str(value)
parts = []
while value:
parts.append(value[-3:])
value = value[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
if count != 1:
message = plural_message
message_dict = self.translations.get("plural", {})
else:
message_dict = self.translations.get("singular", {})
else:
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
def __init__(self, code, translations):
try:
# python 2
self.ngettext = translations.ungettext
self.gettext = translations.ugettext
except AttributeError:
# python 3
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
super(GettextLocale, self).__init__(code, translations)
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},
"ar_AR": {"name_en": u("Arabic"), "name": u("\u0627\u0644\u0639\u0631\u0628\u064a\u0629")},
"bg_BG": {"name_en": u("Bulgarian"), "name": u("\u0411\u044a\u043b\u0433\u0430\u0440\u0441\u043a\u0438")},
"bn_IN": {"name_en": u("Bengali"), "name": u("\u09ac\u09be\u0982\u09b2\u09be")},
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
"ca_ES": {"name_en": u("Catalan"), "name": u("Catal\xe0")},
"cs_CZ": {"name_en": u("Czech"), "name": u("\u010ce\u0161tina")},
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
"el_GR": {"name_en": u("Greek"), "name": u("\u0395\u03bb\u03bb\u03b7\u03bd\u03b9\u03ba\u03ac")},
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Espa\xf1ol (Espa\xf1a)")},
"es_LA": {"name_en": u("Spanish"), "name": u("Espa\xf1ol")},
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
"fa_IR": {"name_en": u("Persian"), "name": u("\u0641\u0627\u0631\u0633\u06cc")},
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Fran\xe7ais (Canada)")},
"fr_FR": {"name_en": u("French"), "name": u("Fran\xe7ais")},
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
"he_IL": {"name_en": u("Hebrew"), "name": u("\u05e2\u05d1\u05e8\u05d9\u05ea")},
"hi_IN": {"name_en": u("Hindi"), "name": u("\u0939\u093f\u0928\u094d\u0926\u0940")},
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
"is_IS": {"name_en": u("Icelandic"), "name": u("\xcdslenska")},
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
"ja_JP": {"name_en": u("Japanese"), "name": u("\u65e5\u672c\u8a9e")},
"ko_KR": {"name_en": u("Korean"), "name": u("\ud55c\uad6d\uc5b4")},
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvi\u0173")},
"lv_LV": {"name_en": u("Latvian"), "name": u("Latvie\u0161u")},
"mk_MK": {"name_en": u("Macedonian"), "name": u("\u041c\u0430\u043a\u0435\u0434\u043e\u043d\u0441\u043a\u0438")},
"ml_IN": {"name_en": u("Malayalam"), "name": u("\u0d2e\u0d32\u0d2f\u0d3e\u0d33\u0d02")},
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokm\xe5l)")},
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
"pa_IN": {"name_en": u("Punjabi"), "name": u("\u0a2a\u0a70\u0a1c\u0a3e\u0a2c\u0a40")},
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Portugu\xeas (Brasil)")},
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Portugu\xeas (Portugal)")},
"ro_RO": {"name_en": u("Romanian"), "name": u("Rom\xe2n\u0103")},
"ru_RU": {"name_en": u("Russian"), "name": u("\u0420\u0443\u0441\u0441\u043a\u0438\u0439")},
"sk_SK": {"name_en": u("Slovak"), "name": u("Sloven\u010dina")},
"sl_SI": {"name_en": u("Slovenian"), "name": u("Sloven\u0161\u010dina")},
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
"sr_RS": {"name_en": u("Serbian"), "name": u("\u0421\u0440\u043f\u0441\u043a\u0438")},
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
"ta_IN": {"name_en": u("Tamil"), "name": u("\u0ba4\u0bae\u0bbf\u0bb4\u0bcd")},
"te_IN": {"name_en": u("Telugu"), "name": u("\u0c24\u0c46\u0c32\u0c41\u0c17\u0c41")},
"th_TH": {"name_en": u("Thai"), "name": u("\u0e20\u0e32\u0e29\u0e32\u0e44\u0e17\u0e22")},
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
"tr_TR": {"name_en": u("Turkish"), "name": u("T\xfcrk\xe7e")},
"uk_UA": {"name_en": u("Ukraini "), "name": u("\u0423\u043a\u0440\u0430\u0457\u043d\u0441\u044c\u043a\u0430")},
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Ti\u1ebfng Vi\u1ec7t")},
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("\u4e2d\u6587(\u7b80\u4f53)")},
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("\u4e2d\u6587(\u7e41\u9ad4)")},
}
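# A minimal illustrative sketch (the requested codes are hypothetical):
# `get` walks the given codes in order and falls back to the default
# locale, so raw Accept-Language values can be passed straight through.
def _example_negotiation():
    # "en-GB" is normalized to "en_GB"; if neither it nor plain "en" has
    # loaded translations, the default locale (en_US unless changed via
    # set_default_locale) is returned.
    return get("en-GB", "de").code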

View file

@ -0,0 +1,205 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Logging support for Tornado.
Tornado uses three logger streams:
* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and
potentially other servers in the future)
* ``tornado.application``: Logging of errors from application code (i.e.
uncaught exceptions from callbacks)
* ``tornado.general``: General-purpose logging, including any errors
or warnings from Tornado itself.
These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
from __future__ import absolute_import, division, print_function, with_statement
import logging
import logging.handlers
import sys
import time
from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type
try:
import curses
except ImportError:
curses = None
# Logger objects for internal tornado use
access_log = logging.getLogger("tornado.access")
app_log = logging.getLogger("tornado.application")
gen_log = logging.getLogger("tornado.general")
def _stderr_supports_color():
color = False
if curses and sys.stderr.isatty():
try:
curses.setupterm()
if curses.tigetnum("colors") > 0:
color = True
except Exception:
pass
return color
class LogFormatter(logging.Formatter):
"""Log formatter used in Tornado.
Key features of this formatter are:
* Color support when logging to a terminal that supports it.
* Timestamps on every log line.
* Robust against str/bytes encoding problems.
This formatter is enabled automatically by
`tornado.options.parse_command_line` (unless ``--logging=none`` is
used).
"""
def __init__(self, color=True, *args, **kwargs):
logging.Formatter.__init__(self, *args, **kwargs)
self._color = color and _stderr_supports_color()
if self._color:
# The curses module has some str/bytes confusion in
# python3. Until version 3.2.3, most methods return
# bytes, but only accept strings. In addition, we want to
# output these strings with the logging module, which
# works with unicode strings. The explicit calls to
# unicode() below are harmless in python2 but will do the
# right conversion in python 3.
fg_color = (curses.tigetstr("setaf") or
curses.tigetstr("setf") or "")
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode_type(fg_color, "ascii")
self._colors = {
logging.DEBUG: unicode_type(curses.tparm(fg_color, 4), # Blue
"ascii"),
logging.INFO: unicode_type(curses.tparm(fg_color, 2), # Green
"ascii"),
logging.WARNING: unicode_type(curses.tparm(fg_color, 3), # Yellow
"ascii"),
logging.ERROR: unicode_type(curses.tparm(fg_color, 1), # Red
"ascii"),
}
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
def format(self, record):
try:
record.message = record.getMessage()
except Exception as e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
assert isinstance(record.message, basestring_type) # guaranteed by logging
record.asctime = time.strftime(
"%y%m%d %H:%M:%S", self.converter(record.created))
prefix = '[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]' % \
record.__dict__
if self._color:
prefix = (self._colors.get(record.levelno, self._normal) +
prefix + self._normal)
# Encoding notes: The logging module prefers to work with character
# strings, but only enforces that log messages are instances of
# basestring. In python 2, non-ascii bytestrings will make
# their way through the logging framework until they blow up with
# an unhelpful decoding error (with this formatter it happens
# when we attach the prefix, but there are other opportunities for
# exceptions further along in the framework).
#
# If a byte string makes it this far, convert it to unicode to
# ensure it will make it out to the logs. Use repr() as a fallback
# to ensure that all byte strings can be converted successfully,
# but don't do it by default so we don't add extra quotes to ascii
# bytestrings. This is a bit of a hacky place to do this, but
# it's worth it since the encoding errors that would otherwise
# result are so useless (and tornado is fond of using utf8-encoded
        # byte strings wherever possible).
def safe_unicode(s):
try:
return _unicode(s)
except UnicodeDecodeError:
return repr(s)
formatted = prefix + " " + safe_unicode(record.message)
if record.exc_info:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
# exc_text contains multiple lines. We need to safe_unicode
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
lines.extend(safe_unicode(ln) for ln in record.exc_text.split('\n'))
formatted = '\n'.join(lines)
return formatted.replace("\n", "\n ")
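# A minimal illustrative sketch: enable_pretty_logging below wires this
# formatter up from tornado.options, but it can also be attached by hand
# to any standard-library logging handler.
def _example_manual_formatter():
    channel = logging.StreamHandler()
    # Color is applied only when stderr is a tty whose terminfo reports
    # color support (see _stderr_supports_color above).
    channel.setFormatter(LogFormatter(color=True))
    logging.getLogger().addHandler(channel)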
def enable_pretty_logging(options=None, logger=None):
"""Turns on formatted logging output as configured.
    This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`.
"""
if options is None:
from tornado.options import options
if options.logging == 'none':
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
if (options.log_to_stderr or
(options.log_to_stderr is None and not logger.handlers)):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
def define_logging_options(options=None):
if options is None:
# late import to prevent cycle
from tornado.options import options
options.define("logging", default="info",
help=("Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."),
metavar="debug|info|warning|error|none")
options.define("log_to_stderr", type=bool, default=None,
help=("Send log output to stderr (colorized if possible). "
"By default use stderr if --log_file_prefix is not set and "
"no other logging is configured."))
options.define("log_file_prefix", type=str, default=None, metavar="PATH",
help=("Path prefix for log files. "
"Note that if you are running multiple tornado processes, "
"log_file_prefix must be different for each of them (e.g. "
"include the port number)"))
options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
help="max size of log files before rollover")
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
options.add_parse_callback(enable_pretty_logging)
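# A minimal illustrative sketch: in a typical main() the global options
# object already carries the definitions above, and parsing the command
# line fires the enable_pretty_logging parse callback.
def _example_main():
    from tornado.options import parse_command_line
    # e.g. --logging=debug --log_file_prefix=/tmp/app.log
    parse_command_line()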

View file

@ -0,0 +1,459 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Miscellaneous network utility code."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import re
import socket
import ssl
import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import Configurable
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
the given address maps to multiple IP addresses, which is most common
for mixed IPv4 and IPv6 use).
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen() <socket.socket.listen>`.
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
"""
sockets = []
if address == "":
address = None
if not socket.has_ipv6 and family == socket.AF_UNSPEC:
# Python can be compiled with --disable-ipv6, which causes
# operations on AF_INET6 sockets to fail, but does not
# automatically exclude those results from getaddrinfo
# results.
# http://bugs.python.org/issue16208
family = socket.AF_INET
if flags is None:
flags = socket.AI_PASSIVE
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
try:
sock = socket.socket(af, socktype, proto)
except socket.error as e:
if e.args[0] == errno.EAFNOSUPPORT:
continue
raise
set_close_exec(sock.fileno())
if os.name != 'nt':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
# 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
# separate sockets *must* be used to listen for both ipv4
# and ipv6. For consistency, always disable ipv4 on our
# ipv6 sockets and use a separate ipv4 socket when needed.
#
# Python 2.x on windows doesn't have IPPROTO_IPV6.
if hasattr(socket, "IPPROTO_IPV6"):
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
sock.setblocking(0)
sock.bind(sockaddr)
sock.listen(backlog)
sockets.append(sock)
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=128):
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
If any other file with that name exists, an exception will be
raised.
Returns a socket object (not a list of socket objects like
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
try:
st = os.stat(file)
except OSError as err:
if err.errno != errno.ENOENT:
raise
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
sock.listen(backlog)
return sock
def add_accept_handler(sock, callback, io_loop=None):
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
be run (``connection`` is a socket object, and ``address`` is the
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
"""
if io_loop is None:
io_loop = IOLoop.current()
def accept_handler(fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error as e:
# EWOULDBLOCK and EAGAIN indicate we have accepted every
# connection that is available.
if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
return
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if e.args[0] == errno.ECONNABORTED:
continue
raise
callback(connection, address)
io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ)
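# A minimal illustrative sketch (the port and handler body are
# hypothetical): combining bind_sockets with add_accept_handler yields a
# bare-bones TCP listener.
def _example_listener():
    def on_connection(connection, address):
        print("accepted connection from", address)
        connection.close()
    for sock in bind_sockets(8888):
        add_accept_handler(sock, on_connection)
    IOLoop.instance().start()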
def is_valid_ip(ip):
"""Returns true if the given string is a well-formed IP address.
Supports IPv4 and IPv6.
"""
try:
res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
socket.SOCK_STREAM,
0, socket.AI_NUMERICHOST)
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
return False
raise
return True
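# A minimal illustrative sketch: because of AI_NUMERICHOST, is_valid_ip
# never performs a DNS lookup, so hostnames are rejected rather than
# resolved.
def _example_is_valid_ip():
    assert is_valid_ip("127.0.0.1")
    assert is_valid_ip("::1")  # assumes this build of Python has IPv6
    assert not is_valid_ip("example.com")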
class Resolver(Configurable):
"""Configurable asynchronous DNS resolver interface.
By default, a blocking implementation is used (which simply calls
`socket.getaddrinfo`). An alternative implementation can be
chosen with the `Resolver.configure <.Configurable.configure>`
class method::
Resolver.configure('tornado.netutil.ThreadedResolver')
The implementations of this interface included with Tornado are
* `tornado.netutil.BlockingResolver`
* `tornado.netutil.ThreadedResolver`
* `tornado.netutil.OverrideResolver`
* `tornado.platform.twisted.TwistedResolver`
* `tornado.platform.caresresolver.CaresResolver`
"""
@classmethod
def configurable_base(cls):
return Resolver
@classmethod
def configurable_default(cls):
return BlockingResolver
def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
literal IP address.
Returns a `.Future` whose result is a list of (family,
address) pairs, where address is a tuple suitable to pass to
`socket.connect <socket.socket.connect>` (i.e. a ``(host,
port)`` pair for IPv4; additional fields may be present for
IPv6). If a ``callback`` is passed, it will be run with the
result as an argument when it is complete.
"""
raise NotImplementedError()
def close(self):
"""Closes the `Resolver`, freeing any resources used.
.. versionadded:: 3.1
"""
pass
class ExecutorResolver(Resolver):
"""Resolver implementation using a `concurrent.futures.Executor`.
Use this instead of `ThreadedResolver` when you require additional
control over the executor being used.
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
"""
def initialize(self, io_loop=None, executor=None, close_executor=True):
self.io_loop = io_loop or IOLoop.current()
if executor is not None:
self.executor = executor
self.close_executor = close_executor
else:
self.executor = dummy_executor
self.close_executor = False
def close(self):
if self.close_executor:
self.executor.shutdown()
self.executor = None
@run_on_executor
def resolve(self, host, port, family=socket.AF_UNSPEC):
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
# one here. The socket type used here doesn't seem to actually
# matter (we discard the one we get back in the results),
# so the addresses we return should still be usable with SOCK_DGRAM.
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
results = []
for family, socktype, proto, canonname, address in addrinfo:
results.append((family, address))
return results
class BlockingResolver(ExecutorResolver):
"""Default `Resolver` implementation, using `socket.getaddrinfo`.
The `.IOLoop` will be blocked during the resolution, although the
callback will not be run until the next `.IOLoop` iteration.
"""
def initialize(self, io_loop=None):
super(BlockingResolver, self).initialize(io_loop=io_loop)
class ThreadedResolver(ExecutorResolver):
"""Multithreaded non-blocking `Resolver` implementation.
Requires the `concurrent.futures` package to be installed
(available in the standard library since Python 3.2,
installable with ``pip install futures`` in older versions).
The thread pool size can be configured with::
Resolver.configure('tornado.netutil.ThreadedResolver',
num_threads=10)
.. versionchanged:: 3.1
All ``ThreadedResolvers`` share a single thread pool, whose
size is set by the first one to be created.
"""
_threadpool = None
_threadpool_pid = None
def initialize(self, io_loop=None, num_threads=10):
threadpool = ThreadedResolver._create_threadpool(num_threads)
super(ThreadedResolver, self).initialize(
io_loop=io_loop, executor=threadpool, close_executor=False)
@classmethod
def _create_threadpool(cls, num_threads):
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
# was when we created the pool then delete it.
cls._threadpool = None
if cls._threadpool is None:
from concurrent.futures import ThreadPoolExecutor
cls._threadpool = ThreadPoolExecutor(num_threads)
cls._threadpool_pid = pid
return cls._threadpool
class OverrideResolver(Resolver):
"""Wraps a resolver with a mapping of overrides.
This can be used to make local DNS changes (e.g. for testing)
without modifying system-wide settings.
The mapping can contain either host strings or host-port pairs.
"""
def initialize(self, resolver, mapping):
self.resolver = resolver
self.mapping = mapping
def close(self):
self.resolver.close()
def resolve(self, host, port, *args, **kwargs):
if (host, port) in self.mapping:
host, port = self.mapping[(host, port)]
elif host in self.mapping:
host = self.mapping[host]
return self.resolver.resolve(host, port, *args, **kwargs)
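# A minimal illustrative sketch (the hostname is hypothetical): an
# OverrideResolver redirecting one name to localhost for testing, layered
# over the default blocking implementation.
def _example_override_resolver():
    resolver = OverrideResolver(resolver=BlockingResolver(),
                                mapping={"api.example.com": "127.0.0.1"})
    # Resolves as if api.example.com were 127.0.0.1; returns a Future.
    return resolver.resolve("api.example.com", 80)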
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
'cert_reqs', 'ca_certs', 'ciphers'])
def ssl_options_to_context(ssl_options):
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
to use features like SNI or NPN.
"""
if isinstance(ssl_options, dict):
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
if (not hasattr(ssl, 'SSLContext') or
isinstance(ssl_options, ssl.SSLContext)):
return ssl_options
context = ssl.SSLContext(
ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
if 'certfile' in ssl_options:
context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
if 'cert_reqs' in ssl_options:
context.verify_mode = ssl_options['cert_reqs']
if 'ca_certs' in ssl_options:
context.load_verify_locations(ssl_options['ca_certs'])
if 'ciphers' in ssl_options:
context.set_ciphers(ssl_options['ciphers'])
return context
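# Illustrative usage sketch (not part of Tornado): upgrade a keyword-style
# ssl_options dict to an SSLContext, e.g. to gain SNI support.  The
# certificate file names are placeholder assumptions.
def _example_ssl_options_to_context():
    opts = dict(certfile='server.crt', keyfile='server.key',
                cert_reqs=ssl.CERT_NONE)
    context = ssl_options_to_context(opts)   # ssl.SSLContext on Python 3.2+
    return context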
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either a dictionary (as accepted by
`ssl_options_to_context`) or an `ssl.SSLContext` object.
Additional keyword arguments are passed to ``wrap_socket``
(either the `~ssl.SSLContext` method or the `ssl` module function
as appropriate).
"""
context = ssl_options_to_context(ssl_options)
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
if server_hostname is not None and getattr(ssl, 'HAS_SNI', False):
# Python doesn't have server-side SNI support so we can't
# really unittest this, but it can be manually tested with
# python3.2 -m tornado.httpclient https://sni.velox.ch
return context.wrap_socket(socket, server_hostname=server_hostname,
**kwargs)
else:
return context.wrap_socket(socket, **kwargs)
else:
return ssl.wrap_socket(socket, **dict(context, **kwargs))
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
else:
# match_hostname was added to the standard library ssl module in python 3.2.
# The following code was backported for older releases and copied from
# https://bitbucket.org/brandon/backports.ssl_match_hostname
class SSLCertificateError(ValueError):
pass
def _dnsname_to_pat(dn, max_wildcards=1):
pats = []
for frag in dn.split(r'.'):
if frag.count('*') > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a
# reasonable choice.
raise SSLCertificateError(
"too many wildcards in certificate DNS name: " + repr(dn))
if frag == '*':
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append('[^.]+')
else:
# Otherwise, '*' matches any dotless fragment.
frag = re.escape(frag)
pats.append(frag.replace(r'\*', '[^.]*'))
return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
def ssl_match_hostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
are mostly followed, but IP addresses are not accepted for *hostname*.
CertificateError is raised on failure. On success, the function
returns nothing.
"""
if not cert:
raise ValueError("empty or no certificate")
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
if key == 'DNS':
if _dnsname_to_pat(value).match(hostname):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get('subject', ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == 'commonName':
if _dnsname_to_pat(value).match(hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise SSLCertificateError("hostname %r "
"doesn't match either of %s"
% (hostname, ', '.join(map(repr, dnsnames))))
elif len(dnsnames) == 1:
raise SSLCertificateError("hostname %r "
"doesn't match %r"
% (hostname, dnsnames[0]))
else:
raise SSLCertificateError("no appropriate commonName or "
"subjectAltName fields were found")

View file

@ -0,0 +1,539 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A command line parsing module that lets modules define their own options.
Each module defines its own options which are added to the global
option namespace, e.g.::
from tornado.options import define, options
define("mysql_host", default="127.0.0.1:3306", help="Main user DB")
define("memcache_hosts", default="127.0.0.1:11011", multiple=True,
help="Main user memcache servers")
def connect():
db = database.Connection(options.mysql_host)
...
The ``main()`` method of your application does not need to be aware of all of
the options used throughout your program; they are all automatically loaded
when the modules are loaded. However, all modules that define options
must have been imported before the command line is parsed.
Your ``main()`` method can parse the command line or parse a config file with
either::
tornado.options.parse_command_line()
# or
tornado.options.parse_config_file("/etc/server.conf")
Command line formats are what you would expect (``--myoption=myvalue``).
Config files are just Python files. Global names become options, e.g.::
myoption = "myvalue"
myotheroption = "myothervalue"
We support `datetimes <datetime.datetime>`, `timedeltas
<datetime.timedelta>`, ints, and floats (just pass a ``type`` kwarg to
`define`). We also accept multi-value options. See the documentation for
`define()` below.
`tornado.options.options` is a singleton instance of `OptionParser`, and
the top-level functions in this module (`define`, `parse_command_line`, etc)
simply call methods on it. You may create additional `OptionParser`
instances to define isolated sets of options, such as for subcommands.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import numbers
import re
import sys
import os
import textwrap
from tornado.escape import _unicode
from tornado.log import define_logging_options
from tornado import stack_context
from tornado.util import basestring_type, exec_in
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
class OptionParser(object):
"""A collection of options, a dictionary with object-like access.
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
def __init__(self):
# we have to use self.__dict__ because we override setattr.
self.__dict__['_options'] = {}
self.__dict__['_parse_callbacks'] = []
self.define("help", type=bool, help="show this help information",
callback=self._help_callback)
def __getattr__(self, name):
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name, value):
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def __iter__(self):
return iter(self._options)
def __getitem__(self, item):
return self._options[item].value()
def items(self):
"""A sequence of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(name, opt.value()) for name, opt in self._options.items()]
def groups(self):
"""The set of option-groups created by ``define``.
.. versionadded:: 3.1
"""
return set(opt.group_name for opt in self._options.values())
def group_dict(self, group):
"""The names and values of options in a group.
Useful for copying options into Application settings::
from tornado.options import define, parse_command_line, options
define('template_path', group='application')
define('static_path', group='application')
parse_command_line()
application = Application(
handlers, **options.group_dict('application'))
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items()
if not group or group == opt.group_name)
def as_dict(self):
"""The names and values of all options.
.. versionadded:: 3.1
"""
return dict(
(name, opt.value()) for name, opt in self._options.items())
def define(self, name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines a new command line option.
If ``type`` is given (one of str, float, int, datetime, or timedelta)
or can be inferred from the ``default``, we parse the command line
arguments based on the given type. If ``multiple`` is True, we accept
comma-separated values, and the option value is always a list.
For multi-value integers, we also accept the syntax ``x:y``, which
turns into ``range(x, y)`` - very useful for long integer ranges.
``help`` and ``metavar`` are used to construct the
automatically generated command line help string. The help
message is formatted like::
--name=METAVAR help string
``group`` is used to group the defined options in logical
groups. By default, command line options are grouped by the
file in which they are defined.
Command line option names must be unique globally. They can be parsed
from the command line with `parse_command_line` or parsed from a
config file with `parse_config_file`.
If a ``callback`` is given, it will be run with the new value whenever
the option is changed. This can be used to combine command-line
and file-based options::
define("config", type=str, help="path to config file",
callback=lambda path: parse_config_file(path, final=False))
With this definition, options in the file specified by ``--config`` will
override options set earlier on the command line, but can be overridden
by later flags.
"""
if name in self._options:
raise Error("Option %r already defined in %s" %
(name, self._options[name].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group
else:
group_name = file_name
self._options[name] = _Option(name, file_name=file_name,
default=default, type=type, help=help,
metavar=metavar, multiple=multiple,
group_name=group_name,
callback=callback)
def parse_command_line(self, args=None, final=True):
"""Parses all options given on the command line (defaults to
`sys.argv`).
Note that ``args[0]`` is ignored since it is the program name
in `sys.argv`.
We return a list of all arguments that are not parsed as options.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
if args is None:
args = sys.argv
remaining = []
for i in range(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i + 1:]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if name not in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error('Option %r requires a value' % name)
option.parse(value)
if final:
self.run_parse_callbacks()
return remaining
def parse_config_file(self, path, final=True):
"""Parses and loads the Python config file at the given path.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
config = {}
with open(path) as f:
exec_in(f.read(), config, config)
for name in config:
if name in self._options:
self._options[name].set(config[name])
if final:
self.run_parse_callbacks()
def print_help(self, file=None):
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
print("\nOptions:\n", file=file)
by_group = {}
for option in self._options.values():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename:
print("\n%s options:\n" % os.path.normpath(filename), file=file)
o.sort(key=lambda option: option.name)
for option in o:
prefix = option.name
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
if option.default is not None and option.default != '':
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
lines.insert(0, '')
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
print("%-34s %s" % (' ', line), file=file)
print(file=file)
def _help_callback(self, value):
if value:
self.print_help()
sys.exit(0)
def add_parse_callback(self, callback):
"""Adds a parse callback, to be invoked when option parsing is done."""
self._parse_callbacks.append(stack_context.wrap(callback))
def run_parse_callbacks(self):
for callback in self._parse_callbacks:
callback()
def mockable(self):
"""Returns a wrapper around self that is compatible with
`mock.patch <unittest.mock.patch>`.
The `mock.patch <unittest.mock.patch>` function (included in
the standard library `unittest.mock` package since Python 3.3,
or in the third-party ``mock`` package for older versions of
Python) is incompatible with objects like ``options`` that
override ``__getattr__`` and ``__setattr__``. This function
returns an object that can be used with `mock.patch.object
<unittest.mock.patch.object>` to modify option values::
with mock.patch.object(options.mockable(), 'name', value):
assert options.name == value
"""
return _Mockable(self)
class _Mockable(object):
"""`mock.patch` compatible wrapper for `OptionParser`.
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
the attribute it set instead of setting a new one (assuming that
the object does not capture ``__setattr__``, so the patch
created a new attribute in ``__dict__``).
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
def __init__(self, options):
# Modify __dict__ directly to bypass __setattr__
self.__dict__['_options'] = options
self.__dict__['_originals'] = {}
def __getattr__(self, name):
return getattr(self._options, name)
def __setattr__(self, name, value):
assert name not in self._originals, "don't reuse mockable objects"
self._originals[name] = getattr(self._options, name)
setattr(self._options, name, value)
def __delattr__(self, name):
setattr(self._options, name, self._originals.pop(name))
class _Option(object):
def __init__(self, name, default=None, type=basestring_type, help=None,
metavar=None, multiple=False, file_name=None, group_name=None,
callback=None):
if default is None and multiple:
default = []
self.name = name
self.type = type
self.help = help
self.metavar = metavar
self.multiple = multiple
self.file_name = file_name
self.group_name = group_name
self.callback = callback
self.default = default
self._value = None
def value(self):
return self.default if self._value is None else self._value
def parse(self, value):
_parse = {
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
}.get(self.type, self.type)
if self.multiple:
self._value = []
for part in value.split(","):
if issubclass(self.type, numbers.Integral):
# allow ranges of the form X:Y (inclusive at both ends)
lo, _, hi = part.partition(":")
lo = _parse(lo)
hi = _parse(hi) if hi else lo
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
else:
self._value = _parse(value)
if self.callback is not None:
self.callback(self._value)
return self.value()
def set(self, value):
if self.multiple:
if not isinstance(value, list):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
for item in value:
if item is not None and not isinstance(item, self.type):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
else:
if value is not None and not isinstance(value, self.type):
raise Error("Option %r is required to be a %s (%s given)" %
(self.name, self.type.__name__, type(value)))
self._value = value
if self.callback is not None:
self.callback(self._value)
# Supported date/time formats in our options
_DATETIME_FORMATS = [
"%a %b %d %H:%M:%S %Y",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M",
"%Y%m%d %H:%M:%S",
"%Y%m%d %H:%M",
"%Y-%m-%d",
"%Y%m%d",
"%H:%M:%S",
"%H:%M",
]
def _parse_datetime(self, value):
for format in self._DATETIME_FORMATS:
try:
return datetime.datetime.strptime(value, format)
except ValueError:
pass
raise Error('Unrecognized date/time format: %r' % value)
_TIMEDELTA_ABBREVS = [
('hours', ['h']),
('minutes', ['m', 'min']),
('seconds', ['s', 'sec']),
('milliseconds', ['ms']),
('microseconds', ['us']),
('days', ['d']),
('weeks', ['w']),
]
_TIMEDELTA_ABBREV_DICT = dict(
(abbrev, full) for full, abbrevs in _TIMEDELTA_ABBREVS
for abbrev in abbrevs)
_FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
_TIMEDELTA_PATTERN = re.compile(
r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
def _parse_timedelta(self, value):
    total = datetime.timedelta()
    start = 0
    while start < len(value):
        m = self._TIMEDELTA_PATTERN.match(value, start)
        if not m:
            raise Error('Unrecognized timedelta format: %r' % value)
        num = float(m.group(1))
        units = m.group(2) or 'seconds'
        units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
        total += datetime.timedelta(**{units: num})
        start = m.end()
    return total
def _parse_bool(self, value):
return value.lower() not in ("false", "0", "f")
def _parse_string(self, value):
return _unicode(value)
options = OptionParser()
"""Global options object.
All defined options are available as attributes on this object.
"""
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
return options.define(name, default=default, type=type, help=help,
metavar=metavar, multiple=multiple, group=group,
callback=callback)
def parse_command_line(args=None, final=True):
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
"""
return options.parse_command_line(args, final=final)
def parse_config_file(path, final=True):
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
"""
return options.parse_config_file(path, final=final)
def print_help(file=None):
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
"""
return options.print_help(file)
def add_parse_callback(callback):
"""Adds a parse callback, to be invoked when option parsing is done.
See `OptionParser.add_parse_callback`
"""
options.add_parse_callback(callback)
# Default options
define_logging_options(options)
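# Illustrative usage sketch (not part of Tornado): exercise the parsing
# rules documented above on an isolated OptionParser.  All option names
# and values are arbitrary assumptions.
def _example_options_usage():
    parser = OptionParser()
    parser.define('port', default=8888, help='listen port')   # type inferred as int
    parser.define('timeout', type=datetime.timedelta, help='e.g. "90s" or "1h30m"')
    parser.define('backends', type=int, default=[], multiple=True,
                  help='comma-separated ints; X:Y expands inclusively')
    rest = parser.parse_command_line(
        ['prog', '--port=80', '--timeout=90s', '--backends=3:5', 'extra'])
    assert parser.port == 80
    assert parser.timeout == datetime.timedelta(seconds=90)
    assert parser.backends == [3, 4, 5]
    assert rest == ['extra']   # non-option arguments are returned
    return parser.as_dict()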

View file

@ -0,0 +1,45 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of platform-specific functionality.
For each function or class described in `tornado.platform.interface`,
the appropriate platform-specific implementation exists in this module.
Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
if os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None

View file

@ -0,0 +1,75 @@
import pycares
import socket
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver, is_valid_ip
class CaresResolver(Resolver):
"""Name resolver based on the c-ares library.
This is a non-blocking and non-threaded resolver. It may not produce
the same results as the system resolver, but can be used for non-blocking
resolution when threads cannot be used.
c-ares fails to resolve some names when ``family`` is ``AF_UNSPEC``,
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
self.fds = {}
def _sock_state_cb(self, fd, readable, writable):
state = ((IOLoop.READ if readable else 0) |
(IOLoop.WRITE if writable else 0))
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
elif fd in self.fds:
self.io_loop.update_handler(fd, state)
self.fds[fd] = state
else:
self.io_loop.add_handler(fd, self._handle_events, state)
self.fds[fd] = state
def _handle_events(self, fd, events):
read_fd = pycares.ARES_SOCKET_BAD
write_fd = pycares.ARES_SOCKET_BAD
if events & IOLoop.READ:
read_fd = fd
if events & IOLoop.WRITE:
write_fd = fd
self.channel.process_fd(read_fd, write_fd)
@gen.coroutine
def resolve(self, host, port, family=0):
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
self.channel.gethostbyname(host, family, (yield gen.Callback(1)))
callback_args = yield gen.Wait(1)
assert isinstance(callback_args, gen.Arguments)
assert not callback_args.kwargs
result, error = callback_args.args
if error:
raise Exception('C-Ares returned error %s: %s while resolving %s' %
(error, pycares.errno.strerror(error), host))
addresses = result.addresses
addrinfo = []
for address in addresses:
if '.' in address:
address_family = socket.AF_INET
elif ':' in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
raise Exception('Requested socket family %d but got %d' %
(family, address_family))
addrinfo.append((address_family, (address, port)))
raise gen.Return(addrinfo)
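# Illustrative usage sketch (not part of Tornado): use the c-ares resolver
# for IPv4-only resolution, per the caveat in the class docstring above.
def _example_cares_resolver():
    resolver = CaresResolver()
    return IOLoop.current().run_sync(
        lambda: resolver.resolve('localhost', 80, family=socket.AF_INET))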

View file

@ -0,0 +1,91 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import socket
from tornado.platform import interface
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe.
For use on platforms that don't have os.pipe() (or where pipes cannot
be passed to select()), but do have sockets. This includes Windows
and Jython.
"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
a.listen(1)
connect_address = a.getsockname() # assigned (host, port) pair
try:
self.writer.connect(connect_address)
break # success
except socket.error as detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
detail[0] != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.send(b"x")
except (IOError, socket.error):
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result:
break
except (IOError, socket.error):
pass
def close(self):
self.reader.close()
self.writer.close()

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""EPoll-based IOLoop implementation for Linux systems."""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import PollIOLoop
class EPollIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs)

View file

@ -0,0 +1,63 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Interfaces for platform-specific functionality.
This module exists primarily for documentation purposes and as base classes
for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, print_function, with_statement
def set_close_exec(fd):
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()
class Waker(object):
"""A socket-like object that can wake another thread from ``select()``.
The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
thread wants to wake up the loop, it calls `wake`. Once it has woken
up, it will call `consume` to do any necessary per-wake cleanup. When
the ``IOLoop`` is closed, it closes its waker too.
"""
def fileno(self):
"""Returns the read file descriptor for this waker.
Must be suitable for use with ``select()`` or equivalent on the
local platform.
"""
raise NotImplementedError()
def write_fileno(self):
"""Returns the write file descriptor for this waker."""
raise NotImplementedError()
def wake(self):
"""Triggers activity on the waker's file descriptor."""
raise NotImplementedError()
def consume(self):
"""Called after the listen has woken up to do any necessary cleanup."""
raise NotImplementedError()
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()

View file

@ -0,0 +1,92 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import IOLoop, PollIOLoop
assert hasattr(select, 'kqueue'), 'kqueue not supported'
class _KQueue(object):
"""A kqueue-based event loop for BSD/Mac systems."""
def __init__(self):
self._kqueue = select.kqueue()
self._active = {}
def fileno(self):
return self._kqueue.fileno()
def close(self):
self._kqueue.close()
def register(self, fd, events):
if fd in self._active:
raise IOError("fd %d already registered" % fd)
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
events = self._active.pop(fd)
self._control(fd, events, select.KQ_EV_DELETE)
def _control(self, fd, events, flags):
kevents = []
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
if events & IOLoop.READ or not kevents:
# Always read when there is not a write
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL
# on Mac OS X (10.6) when there is more than one event in the list.
for kevent in kevents:
self._kqueue.control([kevent], 0)
def poll(self, timeout):
kevents = self._kqueue.control(None, 1000, timeout)
events = {}
for kevent in kevents:
fd = kevent.ident
if kevent.filter == select.KQ_FILTER_READ:
events[fd] = events.get(fd, 0) | IOLoop.READ
if kevent.filter == select.KQ_FILTER_WRITE:
if kevent.flags & select.KQ_EV_EOF:
# If an asynchronous connection is refused, kqueue
# returns a write event with the EOF flag set.
# Turn this into an error for consistency with the
# other IOLoop implementations.
# Note that for read events, EOF may be returned before
# all data has been consumed from the socket buffer,
# so we only check for EOF on write events.
events[fd] = IOLoop.ERROR
else:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
if kevent.flags & select.KQ_EV_ERROR:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class KQueueIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs)

View file

@ -0,0 +1,70 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, print_function, with_statement
import fcntl
import os
from tornado.platform import interface
def set_close_exec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
class Waker(interface.Waker):
def __init__(self):
r, w = os.pipe()
_set_nonblocking(r)
_set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self.reader = os.fdopen(r, "rb", 0)
self.writer = os.fdopen(w, "wb", 0)
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.write(b"x")
except IOError:
pass
def consume(self):
try:
while True:
result = self.reader.read()
if not result:
break
except IOError:
pass
def close(self):
self.reader.close()
self.writer.close()
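# Illustrative usage sketch (not part of Tornado): wake() makes the read
# end of the pipe readable so a blocked select() returns, and consume()
# drains it afterwards.
def _example_waker():
    import select
    waker = Waker()
    try:
        waker.wake()
        readable, _, _ = select.select([waker.fileno()], [], [], 0)
        assert waker.fileno() in readable
        waker.consume()   # drain so the next select() blocks again
    finally:
        waker.close()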

View file

@ -0,0 +1,76 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Select-based IOLoop implementation.
Used as a fallback for systems that don't support epoll or kqueue.
"""
from __future__ import absolute_import, division, print_function, with_statement
import select
from tornado.ioloop import IOLoop, PollIOLoop
class _Select(object):
"""A simple, select()-based IOLoop implementation for non-Linux systems"""
def __init__(self):
self.read_fds = set()
self.write_fds = set()
self.error_fds = set()
self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
def close(self):
pass
def register(self, fd, events):
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
raise IOError("fd %d already registered" % fd)
if events & IOLoop.READ:
self.read_fds.add(fd)
if events & IOLoop.WRITE:
self.write_fds.add(fd)
if events & IOLoop.ERROR:
self.error_fds.add(fd)
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
self.read_fds.discard(fd)
self.write_fds.discard(fd)
self.error_fds.discard(fd)
def poll(self, timeout):
readable, writeable, errors = select.select(
self.read_fds, self.write_fds, self.error_fds, timeout)
events = {}
for fd in readable:
events[fd] = events.get(fd, 0) | IOLoop.READ
for fd in writeable:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
for fd in errors:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class SelectIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs)
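# Illustrative usage sketch (not part of Tornado): force the select()-based
# loop (normally chosen only where epoll and kqueue are unavailable) and
# run a single trivial callback.
def _example_select_ioloop():
    IOLoop.configure('tornado.platform.select.SelectIOLoop')
    loop = IOLoop()
    loop.run_sync(lambda: None)
    loop.close()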

View file

@ -0,0 +1,543 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Note: This module's docs are not currently extracted automatically,
# so changes must be made manually to twisted.rst
# TODO: refactor doc build process to use an appropriate virtualenv
"""Bridges between the Twisted reactor and Tornado IOLoop.
This module lets you run applications and libraries written for
Twisted in a Tornado application. It can be used in two modes,
depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
Twisted on Tornado
------------------
`TornadoReactor` implements the Twisted reactor interface on top of
the Tornado IOLoop. To use it, simply call `install` at the beginning
of the application::
import tornado.platform.twisted
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call `IOLoop.instance().start()`
instead of `reactor.run()`.
It is also possible to create a non-global reactor by calling
`tornado.platform.twisted.TornadoReactor(io_loop)`. However, if
the `IOLoop` and reactor are to be short-lived (such as those used in
unit tests), additional cleanup may be required. Specifically, it is
recommended to call::
reactor.fireSystemEvent('shutdown')
reactor.disconnectAll()
before closing the `IOLoop`.
Tornado on Twisted
------------------
`TwistedIOLoop` implements the Tornado IOLoop interface on top of the Twisted
reactor. Recommended usage::
from tornado.platform.twisted import TwistedIOLoop
from twisted.internet import reactor
TwistedIOLoop().install()
# Set up your tornado application as usual using `IOLoop.instance`
reactor.run()
`TwistedIOLoop` always uses the global Twisted reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
import numbers
import socket
import twisted.internet.abstract
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
from twisted.python import failure, log
from twisted.internet import error
import twisted.names.cache
import twisted.names.client
import twisted.names.hosts
import twisted.names.resolve
from zope.interface import implementer
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
from tornado.log import app_log
from tornado.netutil import Resolver
from tornado.stack_context import NullContext, wrap
from tornado.ioloop import IOLoop
@implementer(IDelayedCall)
class TornadoDelayedCall(object):
"""DelayedCall object for Tornado."""
def __init__(self, reactor, seconds, f, *args, **kw):
self._reactor = reactor
self._func = functools.partial(f, *args, **kw)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
self._active = True
def _called(self):
self._active = False
self._reactor._removeDelayedCall(self)
try:
self._func()
except:
app_log.error("_called caught exception", exc_info=True)
def getTime(self):
return self._time
def cancel(self):
self._active = False
self._reactor._io_loop.remove_timeout(self._timeout)
self._reactor._removeDelayedCall(self)
def delay(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time += seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def reset(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def active(self):
return self._active
@implementer(IReactorTime, IReactorFDSet)
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
Since it is intended to be used in applications where the top-level
event loop is ``io_loop.start()`` rather than ``reactor.run()``,
it is implemented a little differently than other Twisted reactors.
We override `mainLoop` instead of `doIteration` and must implement
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
"""
def __init__(self, io_loop=None):
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
self._io_loop = io_loop
self._readers = {} # map of reader objects to fd
self._writers = {} # map of writer objects to fd
self._fds = {} # a map of fd to a (reader, writer) tuple
self._delayedCalls = {}
PosixReactorBase.__init__(self)
self.addSystemEventTrigger('during', 'shutdown', self.crash)
# IOLoop.start() bypasses some of the reactor initialization.
# Fire off the necessary events if they weren't already triggered
# by reactor.run().
def start_if_necessary():
if not self._started:
self.fireSystemEvent('startup')
self._io_loop.add_callback(start_if_necessary)
# IReactorTime
def seconds(self):
return self._io_loop.time()
def callLater(self, seconds, f, *args, **kw):
dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
self._delayedCalls[dc] = True
return dc
def getDelayedCalls(self):
return [x for x in self._delayedCalls if x._active]
def _removeDelayedCall(self, dc):
if dc in self._delayedCalls:
del self._delayedCalls[dc]
# IReactorThreads
def callFromThread(self, f, *args, **kw):
"""See `twisted.internet.interfaces.IReactorThreads.callFromThread`"""
assert callable(f), "%s is not callable" % f
with NullContext():
# This NullContext is mainly for an edge case when running
# TwistedIOLoop on top of a TornadoReactor.
# TwistedIOLoop.add_callback uses reactor.callFromThread and
# should not pick up additional StackContexts along the way.
self._io_loop.add_callback(f, *args, **kw)
# We don't need the waker code from the super class, Tornado uses
# its own waker.
def installWaker(self):
pass
def wakeUp(self):
pass
# IReactorFDSet
def _invoke_callback(self, fd, events):
if fd not in self._fds:
return
(reader, writer) = self._fds[fd]
if reader:
err = None
if reader.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.READ:
err = log.callWithLogger(reader, reader.doRead)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeReader(reader)
reader.readConnectionLost(failure.Failure(err))
if writer:
err = None
if writer.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.WRITE:
err = log.callWithLogger(writer, writer.doWrite)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeWriter(writer)
writer.writeConnectionLost(failure.Failure(err))
def addReader(self, reader):
"""Add a FileDescriptor for notification of data available to read."""
if reader in self._readers:
# Don't add the reader if it's already there
return
fd = reader.fileno()
self._readers[reader] = fd
if fd in self._fds:
(_, writer) = self._fds[fd]
self._fds[fd] = (reader, writer)
if writer:
# We already registered this fd for write events,
# update it for read events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (reader, None)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.READ)
def addWriter(self, writer):
"""Add a FileDescriptor for notification of data available to write."""
if writer in self._writers:
return
fd = writer.fileno()
self._writers[writer] = fd
if fd in self._fds:
(reader, _) = self._fds[fd]
self._fds[fd] = (reader, writer)
if reader:
# We already registered this fd for read events,
# update it for write events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (None, writer)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.WRITE)
def removeReader(self, reader):
"""Remove a Selectable for notification of data available to read."""
if reader in self._readers:
fd = self._readers.pop(reader)
(_, writer) = self._fds[fd]
if writer:
# We have a writer so we need to update the IOLoop for
# write events only.
self._fds[fd] = (None, writer)
self._io_loop.update_handler(fd, IOLoop.WRITE)
else:
# Since we have no writer registered, we remove the
# entry from _fds and unregister the handler from the
# IOLoop
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeWriter(self, writer):
"""Remove a Selectable for notification of data available to write."""
if writer in self._writers:
fd = self._writers.pop(writer)
(reader, _) = self._fds[fd]
if reader:
# We have a reader so we need to update the IOLoop for
# read events only.
self._fds[fd] = (reader, None)
self._io_loop.update_handler(fd, IOLoop.READ)
else:
# Since we have no reader registered, we remove the
# entry from the _fds and unregister the handler from
# the IOLoop.
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeAll(self):
return self._removeAll(self._readers, self._writers)
def getReaders(self):
return self._readers.keys()
def getWriters(self):
return self._writers.keys()
# The following functions are mainly used in twisted-style test cases;
# it is expected that most users of the TornadoReactor will call
# IOLoop.start() instead of Reactor.run().
def stop(self):
PosixReactorBase.stop(self)
fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown")
self._io_loop.add_callback(fire_shutdown)
def crash(self):
PosixReactorBase.crash(self)
self._io_loop.stop()
def doIteration(self, delay):
raise NotImplementedError("doIteration")
def mainLoop(self):
self._io_loop.start()
class _TestReactor(TornadoReactor):
"""Subclass of TornadoReactor for use in unittests.
This can't go in the test.py file because of import-order dependencies
with the Twisted reactor test builder.
"""
def __init__(self):
# always use a new ioloop
super(_TestReactor, self).__init__(IOLoop())
def listenTCP(self, port, factory, backlog=50, interface=''):
# default to localhost to avoid firewall prompts on the mac
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenTCP(
port, factory, backlog=backlog, interface=interface)
def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenUDP(
port, protocol, interface=interface, maxPacketSize=maxPacketSize)
def install(io_loop=None):
"""Install this package as the default Twisted reactor."""
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
from twisted.internet.main import installReactor
installReactor(reactor)
return reactor
@implementer(IReadDescriptor, IWriteDescriptor)
class _FD(object):
def __init__(self, fd, handler):
self.fd = fd
self.handler = handler
self.reading = False
self.writing = False
self.lost = False
def fileno(self):
return self.fd
def doRead(self):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.READ)
def doWrite(self):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.WRITE)
def connectionLost(self, reason):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.ERROR)
self.lost = True
def logPrefix(self):
return ''
class TwistedIOLoop(tornado.ioloop.IOLoop):
"""IOLoop implementation that runs on Twisted.
Uses the global Twisted reactor by default. To create multiple
`TwistedIOLoops` in the same process, you must pass a unique reactor
when constructing each one.
Not compatible with `tornado.process.Subprocess.set_exit_callback`
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
"""
def initialize(self, reactor=None):
if reactor is None:
import twisted.internet.reactor
reactor = twisted.internet.reactor
self.reactor = reactor
self.fds = {}
self.reactor.callWhenRunning(self.make_current)
def close(self, all_fds=False):
self.reactor.removeAll()
for c in self.reactor.getDelayedCalls():
c.cancel()
def add_handler(self, fd, handler, events):
if fd in self.fds:
raise ValueError('fd %d added twice' % fd)
self.fds[fd] = _FD(fd, wrap(handler))
if events & tornado.ioloop.IOLoop.READ:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
def update_handler(self, fd, events):
if events & tornado.ioloop.IOLoop.READ:
if not self.fds[fd].reading:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
else:
if self.fds[fd].reading:
self.fds[fd].reading = False
self.reactor.removeReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
if not self.fds[fd].writing:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
else:
if self.fds[fd].writing:
self.fds[fd].writing = False
self.reactor.removeWriter(self.fds[fd])
def remove_handler(self, fd):
if fd not in self.fds:
return
self.fds[fd].lost = True
if self.fds[fd].reading:
self.reactor.removeReader(self.fds[fd])
if self.fds[fd].writing:
self.reactor.removeWriter(self.fds[fd])
del self.fds[fd]
def start(self):
self.reactor.run()
def stop(self):
self.reactor.crash()
def _run_callback(self, callback, *args, **kwargs):
try:
callback(*args, **kwargs)
except Exception:
self.handle_callback_exception(callback)
def add_timeout(self, deadline, callback):
if isinstance(deadline, numbers.Real):
delay = max(deadline - self.time(), 0)
elif isinstance(deadline, datetime.timedelta):
delay = tornado.ioloop._Timeout.timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r")
return self.reactor.callLater(delay, self._run_callback, wrap(callback))
def remove_timeout(self, timeout):
if timeout.active():
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
self.reactor.callFromThread(self._run_callback,
wrap(callback), *args, **kwargs)
def add_callback_from_signal(self, callback, *args, **kwargs):
self.add_callback(callback, *args, **kwargs)
class TwistedResolver(Resolver):
"""Twisted-based asynchronous resolver.
This is a non-blocking and non-threaded resolver. It is
recommended only when threads cannot be used, since it has
limitations compared to the standard ``getaddrinfo``-based
`~tornado.netutil.Resolver` and
`~tornado.netutil.ThreadedResolver`. Specifically, it returns at
most one result, and arguments other than ``host`` and ``family``
are ignored. It may fail to resolve when ``family`` is not
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
self.reactor = TornadoReactor(self.io_loop)
host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
reactor=self.reactor)
self.resolver = twisted.names.resolve.ResolverChain(
[host_resolver, cache_resolver, real_resolver])
@gen.coroutine
def resolve(self, host, port, family=0):
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
resolved = host
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(host):
resolved = host
resolved_family = socket.AF_INET6
else:
deferred = self.resolver.getHostByName(utf8(host))
resolved = yield gen.Task(deferred.addCallback)
if twisted.internet.abstract.isIPAddress(resolved):
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(resolved):
resolved_family = socket.AF_INET6
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
raise Exception('Requested socket family %d but got %d' %
(family, resolved_family))
result = [
(resolved_family, (resolved, port)),
]
raise gen.Return(result)
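# Illustrative usage sketch (not part of Tornado): run Twisted code on the
# Tornado IOLoop as the module docstring describes.  install() may only be
# called once per process; the delay and callback are arbitrary, and
# start() blocks until the loop is stopped.
def _example_tornado_reactor():
    reactor = install()                        # TornadoReactor becomes the global reactor
    reactor.callLater(1.0, lambda: None)       # Twisted-style timed call
    tornado.ioloop.IOLoop.instance().start()   # instead of reactor.run()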

View file

@ -0,0 +1,20 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
from __future__ import absolute_import, division, print_function, with_statement
import ctypes
import ctypes.wintypes
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.WinError()

View file

@ -0,0 +1,292 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with multiple processes, including both forking
the server into multiple processes and managing subprocesses.
"""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import multiprocessing
import os
import signal
import subprocess
import sys
import time
from binascii import hexlify
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
try:
long # py2
except NameError:
long = int # py3
def cpu_count():
"""Returns the number of processors on this machine."""
try:
return multiprocessing.cpu_count()
except NotImplementedError:
pass
try:
return os.sysconf("SC_NPROCESSORS_CONF")
except ValueError:
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
def _reseed_random():
if 'random' not in sys.modules:
return
import random
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
try:
seed = long(hexlify(os.urandom(16)), 16)
except NotImplementedError:
seed = int(time.time() * 1000) ^ os.getpid()
random.seed(seed)
def _pipe_cloexec():
r, w = os.pipe()
set_close_exec(r)
set_close_exec(w)
return r, w
_task_id = None
def fork_processes(num_processes, max_restarts=100):
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If ``num_processes`` is given and > 0, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the debug=True option to `tornado.web.Application`).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``fork_processes``.
In each child process, ``fork_processes`` returns its *task id*, a
number between 0 and ``num_processes - 1``. Processes that exit
abnormally (due to a signal or non-zero exit status) are restarted
with the same id (up to ``max_restarts`` times). In the parent
process, ``fork_processes`` returns None if all child processes
have exited normally, but will otherwise only exit by throwing an
exception.
"""
global _task_id
assert _task_id is None
if num_processes is None or num_processes <= 0:
num_processes = cpu_count()
if ioloop.IOLoop.initialized():
raise RuntimeError("Cannot run in multiple processes: IOLoop instance "
"has already been initialized. You cannot call "
"IOLoop.instance() before calling start_processes()")
gen_log.info("Starting %d processes", num_processes)
children = {}
def start_child(i):
pid = os.fork()
if pid == 0:
# child process
_reseed_random()
global _task_id
_task_id = i
return i
else:
children[pid] = i
return None
for i in range(num_processes):
id = start_child(i)
if id is not None:
return id
num_restarts = 0
while children:
try:
pid, status = os.wait()
except OSError as e:
if e.errno == errno.EINTR:
continue
raise
if pid not in children:
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
id, pid, os.WTERMSIG(status))
elif os.WEXITSTATUS(status) != 0:
gen_log.warning("child %d (pid %d) exited with status %d, restarting",
id, pid, os.WEXITSTATUS(status))
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
num_restarts += 1
if num_restarts > max_restarts:
raise RuntimeError("Too many child restarts, giving up")
new_id = start_child(id)
if new_id is not None:
return new_id
# All child processes exited cleanly, so exit the master process
# instead of just returning to right after the call to
# fork_processes (which will probably just start up another IOLoop
# unless the caller checks the return value).
sys.exit(0)
def task_id():
"""Returns the current task id, if any.
Returns None if this process was not created by `fork_processes`.
"""
global _task_id
return _task_id
class Subprocess(object):
"""Wraps ``subprocess.Popen`` with IOStream support.
The constructor is the same as ``subprocess.Popen`` with the following
additions:
* ``stdin``, ``stdout``, and ``stderr`` may have the value
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
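
    A minimal sketch of ``STREAM`` usage (an editorial illustration;
    assumes a Unix-like system where ``echo`` is available)::

        from tornado.ioloop import IOLoop
        from tornado.process import Subprocess

        io_loop = IOLoop.instance()

        def on_stdout(data):
            print(data)  # the child's entire output, once stdout closes

        sub = Subprocess(["echo", "hello"], stdout=Subprocess.STREAM)
        sub.stdout.read_until_close(on_stdout)
        sub.set_exit_callback(lambda returncode: io_loop.stop())
        io_loop.start()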
"""
STREAM = object()
_initialized = False
_waiting = {}
def __init__(self, *args, **kwargs):
self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
to_close = []
if kwargs.get('stdin') is Subprocess.STREAM:
in_r, in_w = _pipe_cloexec()
kwargs['stdin'] = in_r
to_close.append(in_r)
self.stdin = PipeIOStream(in_w, io_loop=self.io_loop)
if kwargs.get('stdout') is Subprocess.STREAM:
out_r, out_w = _pipe_cloexec()
kwargs['stdout'] = out_w
to_close.append(out_w)
self.stdout = PipeIOStream(out_r, io_loop=self.io_loop)
if kwargs.get('stderr') is Subprocess.STREAM:
err_r, err_w = _pipe_cloexec()
kwargs['stderr'] = err_w
to_close.append(err_w)
self.stderr = PipeIOStream(err_r, io_loop=self.io_loop)
self.proc = subprocess.Popen(*args, **kwargs)
for fd in to_close:
os.close(fd)
for attr in ['stdin', 'stdout', 'stderr', 'pid']:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
self._exit_callback = None
self.returncode = None
def set_exit_callback(self, callback):
"""Runs ``callback`` when this process exits.
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHILD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
one ``IOLoop`` to run the signal handlers.
In many cases a close callback on the stdout or stderr streams
can be used as an alternative to an exit callback if the
signal handler is causing a problem.
"""
self._exit_callback = stack_context.wrap(callback)
Subprocess.initialize(self.io_loop)
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHILD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
"""
if cls._initialized:
return
if io_loop is None:
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
cls._initialized = True
@classmethod
def uninitialize(cls):
"""Removes the ``SIGCHILD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)
cls._initialized = False
@classmethod
def _cleanup(cls):
for pid in list(cls._waiting.keys()): # make a copy
cls._try_cleanup_process(pid)
@classmethod
def _try_cleanup_process(cls, pid):
try:
ret_pid, status = os.waitpid(pid, os.WNOHANG)
        except OSError as e:
            if e.args[0] == errno.ECHILD:
                return
            # Re-raise unexpected errors; otherwise ret_pid below would
            # be referenced without ever having been assigned.
            raise
if ret_pid == 0:
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
subproc.io_loop.add_callback_from_signal(
subproc._set_returncode, status)
def _set_returncode(self, status):
if os.WIFSIGNALED(status):
self.returncode = -os.WTERMSIG(status)
else:
assert os.WIFEXITED(status)
self.returncode = os.WEXITSTATUS(status)
if self._exit_callback:
callback = self._exit_callback
self._exit_callback = None
callback(self.returncode)

View file

@ -0,0 +1,497 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.escape import utf8, _unicode, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
from tornado.util import GzipDecompressor
import base64
import collections
import copy
import functools
import os.path
import re
import socket
import ssl
import sys
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urlparse # py2
except ImportError:
import urllib.parse as urlparse # py3
_DEFAULT_CA_CERTS = os.path.dirname(__file__) + '/ca-certificates.crt'
class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
It does not currently implement all applicable parts of the HTTP
specification, but it does enough to work with major web service APIs.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
used.
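
    A minimal usage sketch (an editorial illustration; the URL is a
    placeholder)::

        from tornado.httpclient import AsyncHTTPClient
        from tornado.ioloop import IOLoop

        AsyncHTTPClient.configure(
            "tornado.simple_httpclient.SimpleAsyncHTTPClient")

        def handle_response(response):
            print(response.code)
            IOLoop.instance().stop()

        AsyncHTTPClient().fetch("http://example.com/", handle_response)
        IOLoop.instance().start()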
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
force_instance=True may be used to suppress this behavior.
max_clients is the number of concurrent requests that can be
        in progress. Note that these arguments are only used when the
client is first created, and will be ignored when an existing
client is reused.
hostname_mapping is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like /etc/hosts is not possible or desirable (e.g. in
unittests).
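
        A hostname_mapping sketch (an editorial illustration; the name
        and address are placeholders)::

            client = SimpleAsyncHTTPClient(
                hostname_mapping={"test.example.com": "127.0.0.1"})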
max_buffer_size is the number of bytes that can be read by IOStream. It
        defaults to 100MB.
"""
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
defaults=defaults)
self.max_clients = max_clients
self.queue = collections.deque()
self.active = {}
self.max_buffer_size = max_buffer_size
if resolver:
self.resolver = resolver
self.own_resolver = False
else:
self.resolver = Resolver(io_loop=io_loop)
self.own_resolver = True
if hostname_mapping is not None:
self.resolver = OverrideResolver(resolver=self.resolver,
mapping=hostname_mapping)
def close(self):
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
def fetch_impl(self, request, callback):
self.queue.append((request, callback))
self._process_queue()
if self.queue:
gen_log.debug("max_clients limit reached, request queued. "
"%d active, %d queued requests." % (
len(self.active), len(self.queue)))
def _process_queue(self):
with stack_context.NullContext():
while self.queue and len(self.active) < self.max_clients:
request, callback = self.queue.popleft()
key = object()
self.active[key] = (request, callback)
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.resolver)
def _release_fetch(self, key):
del self.active[key]
self._process_queue()
class _HTTPConnection(object):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, resolver):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
self.request = request
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
self.resolver = resolver
self.code = None
self.headers = None
self.chunks = None
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None
with stack_context.ExceptionStackContext(self._handle_exception):
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
raise ValueError("Unsupported url scheme: %s" %
self.request.url)
            # urlsplit results have hostname and port attributes, but they
            # didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if request.allow_ipv6:
af = socket.AF_UNSPEC
else:
# We only try the first IP we get from getaddrinfo,
# so restrict to ipv4 by default.
af = socket.AF_INET
self.resolver.resolve(host, port, af, callback=self._on_resolve)
def _on_resolve(self, addrinfo):
self.stream = self._create_stream(addrinfo)
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(self._on_timeout))
self.stream.set_close_callback(self._on_close)
        # self.parsed.hostname mishandles ipv6 literals until python 2.7;
        # use the correctly parsed value computed in __init__ instead.
sockaddr = addrinfo[0][1]
self.stream.connect(sockaddr, self._on_connect,
server_hostname=self.parsed_hostname)
def _create_stream(self, addrinfo):
af = addrinfo[0][0]
if self.parsed.scheme == "https":
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
else:
ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
if self.request.client_cert is not None:
ssl_options["certfile"] = self.request.client_cert
# SSL interoperability is tricky. We want to disable
# SSLv2 for security reasons; it wasn't disabled by default
# until openssl 1.0. The best way to do this is to use
# the SSL_OP_NO_SSLv2, but that wasn't exposed to python
# until 3.2. Python 2.7 adds the ciphers argument, which
# can also be used to disable SSLv2. As a last resort
# on python 2.6, we set ssl_version to SSLv3. This is
# more narrow than we'd like since it also breaks
# compatibility with servers configured for TLSv1 only,
# but nearly all servers support SSLv3:
# http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
if sys.version_info >= (2, 7):
ssl_options["ciphers"] = "DEFAULT:!SSLv2"
else:
# This is really only necessary for pre-1.0 versions
# of openssl, but python 2.6 doesn't expose version
# information.
ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
return SSLIOStream(socket.socket(af),
io_loop=self.io_loop,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size)
else:
return IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size)
def _on_timeout(self):
self._timeout = None
if self.final_callback is not None:
raise HTTPError(599, "Timeout")
def _remove_timeout(self):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _on_connect(self):
self._remove_timeout()
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
stack_context.wrap(self._on_timeout))
if (self.request.method not in self._SUPPORTED_METHODS and
not self.request.allow_nonstandard_methods):
raise KeyError("unknown method %s" % self.request.method)
for key in ('network_interface',
'proxy_host', 'proxy_port',
'proxy_username', 'proxy_password'):
if getattr(self.request, key, None):
raise NotImplementedError('%s not supported' % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
if '@' in self.parsed.netloc:
self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
else:
self.request.headers["Host"] = self.parsed.netloc
username, password = None, None
if self.parsed.username is not None:
username, password = self.parsed.username, self.parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
password = self.request.auth_password or ''
if username is not None:
if self.request.auth_mode not in (None, "basic"):
raise ValueError("unsupported auth_mode %s",
self.request.auth_mode)
auth = utf8(username) + b":" + utf8(password)
self.request.headers["Authorization"] = (b"Basic " +
base64.b64encode(auth))
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
assert self.request.body is not None
else:
assert self.request.body is None
if self.request.body is not None:
self.request.headers["Content-Length"] = str(len(
self.request.body))
if (self.request.method == "POST" and
"Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
if self.request.use_gzip:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
req_path))]
for k, v in self.request.headers.get_all():
line = utf8(k) + b": " + utf8(v)
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
request_lines.append(line)
request_str = b"\r\n".join(request_lines) + b"\r\n\r\n"
if self.request.body is not None:
request_str += self.request.body
self.stream.set_nodelay(True)
self.stream.write(request_str)
self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
def _release(self):
if self.release_callback is not None:
release_callback = self.release_callback
self.release_callback = None
release_callback()
def _run_callback(self, response):
self._release()
if self.final_callback is not None:
final_callback = self.final_callback
self.final_callback = None
self.io_loop.add_callback(final_callback, response)
def _handle_exception(self, typ, value, tb):
if self.final_callback:
self._remove_timeout()
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
))
if hasattr(self, "stream"):
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
# pass it along.
return False
def _on_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
message = str(self.stream.error)
raise HTTPError(599, message)
def _handle_1xx(self, code):
self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
def _on_headers(self, data):
data = native_str(data.decode("latin1"))
first_line, _, header_data = data.partition("\n")
match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
assert match
code = int(match.group(1))
self.headers = HTTPHeaders.parse(header_data)
if 100 <= code < 200:
self._handle_1xx(code)
return
else:
self.code = code
self.reason = match.group(2)
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', self.headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise ValueError("Multiple unequal Content-Lengths: %r" %
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]
content_length = int(self.headers["Content-Length"])
else:
content_length = None
if self.request.header_callback is not None:
# re-attach the newline we split on earlier
self.request.header_callback(first_line + _)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
if self.request.method == "HEAD" or self.code == 304:
# HEAD requests and 304 responses never have content, even
# though they may have content-length headers
self._on_body(b"")
return
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
self._on_body(b"")
return
if (self.request.use_gzip and
self.headers.get("Content-Encoding") == "gzip"):
self._decompressor = GzipDecompressor()
if self.headers.get("Transfer-Encoding") == "chunked":
self.chunks = []
self.stream.read_until(b"\r\n", self._on_chunk_length)
elif content_length is not None:
self.stream.read_bytes(content_length, self._on_body)
else:
self.stream.read_until_close(self._on_body)
def _on_body(self, data):
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
if (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307)):
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urlparse.urljoin(self.request.url,
self.headers["Location"])
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
# Client SHOULD make a GET request after a 303.
# According to the spec, 302 should be followed by the same
# method as the original request, but in practice browsers
# treat 302 the same as 303, and many servers use 302 for
# compatibility with pre-HTTP/1.1 user agents which don't
# understand the 303 status.
if self.code in (302, 303):
new_request.method = "GET"
new_request.body = None
for h in ["Content-Length", "Content-Type",
"Content-Encoding", "Transfer-Encoding"]:
try:
del self.request.headers[h]
except KeyError:
pass
new_request.original_request = original_request
final_callback = self.final_callback
self.final_callback = None
self._release()
self.client.fetch(new_request, final_callback)
self._on_end_request()
return
if self._decompressor:
data = (self._decompressor.decompress(data) +
self._decompressor.flush())
if self.request.streaming_callback:
if self.chunks is None:
# if chunks is not None, we already called streaming_callback
# in _on_chunk_data
self.request.streaming_callback(data)
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
self.code, reason=self.reason,
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
buffer=buffer,
effective_url=self.request.url)
self._run_callback(response)
self._on_end_request()
def _on_end_request(self):
self.stream.close()
def _on_chunk_length(self, data):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
length = int(data.strip(), 16)
if length == 0:
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
if self.request.streaming_callback is not None:
self.request.streaming_callback(tail)
else:
self.chunks.append(tail)
# all the data has been decompressed, so we don't need to
# decompress again in _on_body
self._decompressor = None
self._on_body(b''.join(self.chunks))
else:
self.stream.read_bytes(length + 2, # chunk ends with \r\n
self._on_chunk_data)
def _on_chunk_data(self, data):
assert data[-2:] == b"\r\n"
chunk = data[:-2]
if self._decompressor:
chunk = self._decompressor.decompress(chunk)
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
self.stream.read_until(b"\r\n", self._on_chunk_length)
if __name__ == "__main__":
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
main()

View file

@ -0,0 +1,376 @@
#!/usr/bin/env python
#
# Copyright 2010 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""`StackContext` allows applications to maintain threadlocal-like state
that follows execution as it moves to other execution contexts.
The motivating examples are to eliminate the need for explicit
``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
allow some additional context to be kept for logging.
This is slightly magic, but it's an extension of the idea that an
exception handler is a kind of stack-local state and when that stack
is suspended and resumed in a new context that state needs to be
preserved. `StackContext` shifts the burden of restoring that state
from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
in ``async_callback``) to the mechanisms that transfer control from
one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
thread pools, etc).
Example usage::
@contextlib.contextmanager
def die_on_error():
try:
yield
except Exception:
logging.error("exception in asynchronous operation",exc_info=True)
sys.exit(1)
with StackContext(die_on_error):
        # Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a
stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
(for example, if you're writing a thread pool), use
`.stack_context.wrap()` before any asynchronous operations to capture the
stack context from where the operation was started.
* If you're writing an asynchronous library that has some shared
resources (such as a connection pool), create those shared resources
within a ``with stack_context.NullContext():`` block. This will prevent
``StackContexts`` from leaking from one request to another.
* If you want to write something like an exception handler that will
persist across asynchronous calls, create a new `StackContext` (or
`ExceptionStackContext`), and make your asynchronous calls in a ``with``
block that references your `StackContext`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
import threading
from tornado.util import raise_exc_info
class StackContextInconsistentError(Exception):
pass
class _State(threading.local):
def __init__(self):
self.contexts = (tuple(), None)
_state = _State()
class StackContext(object):
"""Establishes the given context as a StackContext that will be transferred.
Note that the parameter is a callable that returns a context
manager, not the context itself. That is, where for a
non-transferable context manager you would say::
with my_context():
StackContext takes the function itself rather than its result::
with StackContext(my_context):
The result of ``with StackContext() as cb:`` is a deactivation
callback. Run this callback when the StackContext is no longer
needed to ensure that it is not propagated any further (note that
deactivating a context does not affect any instances of that
context that are currently pending). This is an advanced feature
and not necessary in most applications.
"""
def __init__(self, context_factory):
self.context_factory = context_factory
self.contexts = []
self.active = True
def _deactivate(self):
self.active = False
# StackContext protocol
def enter(self):
context = self.context_factory()
self.contexts.append(context)
context.__enter__()
def exit(self, type, value, traceback):
context = self.contexts.pop()
context.__exit__(type, value, traceback)
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0] + (self,), self)
_state.contexts = self.new_contexts
try:
self.enter()
except:
_state.contexts = self.old_contexts
raise
return self._deactivate
def __exit__(self, type, value, traceback):
try:
self.exit(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
# Generator coroutines and with-statements with non-local
# effects interact badly. Check here for signs of
# the stack getting out of sync.
# Note that this check comes after restoring _state.context
# so that if it fails things are left in a (relatively)
# consistent state.
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class ExceptionStackContext(object):
"""Specialization of StackContext for exception handling.
The supplied ``exception_handler`` function will be called in the
event of an uncaught exception in this context. The semantics are
similar to a try/finally clause, and intended use cases are to log
an error, close a socket, or similar cleanup actions. The
``exc_info`` triple ``(type, value, traceback)`` will be passed to the
exception_handler function.
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
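
    A minimal sketch (an editorial illustration; ``start_async_operation``
    is a hypothetical function)::

        import logging
        from tornado.stack_context import ExceptionStackContext

        def handle_exception(typ, value, tb):
            logging.error("async operation failed",
                          exc_info=(typ, value, tb))
            return True  # consume the exception

        with ExceptionStackContext(handle_exception):
            start_async_operation()  # hypothetical asynchronous call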
"""
def __init__(self, exception_handler):
self.exception_handler = exception_handler
self.active = True
def _deactivate(self):
self.active = False
def exit(self, type, value, traceback):
if type is not None:
return self.exception_handler(type, value, traceback)
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0], self)
_state.contexts = self.new_contexts
return self._deactivate
def __exit__(self, type, value, traceback):
try:
if type is not None:
return self.exception_handler(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class NullContext(object):
"""Resets the `StackContext`.
Useful when creating a shared resource on demand (e.g. an
    `.AsyncHTTPClient`) where the stack that caused the creation is
not relevant to future operations.
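
    A minimal sketch (an editorial illustration)::

        from tornado.httpclient import AsyncHTTPClient
        from tornado.stack_context import NullContext

        def shared_client():
            # Create the shared client outside the caller's StackContext
            # so that context is not captured (and leaked) by it.
            with NullContext():
                return AsyncHTTPClient()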
"""
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (tuple(), None)
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
def _remove_deactivated(contexts):
"""Remove deactivated handlers from the chain"""
# Clean ctx handlers
stack_contexts = tuple([h for h in contexts[0] if h.active])
# Find new head
head = contexts[1]
while head is not None and not head.active:
head = head.old_contexts[1]
# Process chain
ctx = head
while ctx is not None:
parent = ctx.old_contexts[1]
while parent is not None:
if parent.active:
break
ctx.old_contexts = parent.old_contexts
parent = parent.old_contexts[1]
ctx = parent
return (stack_contexts, head)
def wrap(fn):
"""Returns a callable object that will restore the current `StackContext`
when executed.
Use this whenever saving a callback to be executed later in a
different execution context (either in a different thread or
asynchronously in the same thread).
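
    A minimal thread-pool sketch (an editorial illustration; ``compute``
    and ``on_done`` are hypothetical callables)::

        import threading
        from tornado.ioloop import IOLoop
        from tornado.stack_context import wrap

        def run_in_thread(compute, on_done):
            callback = wrap(on_done)  # capture the caller's context now
            io_loop = IOLoop.current()

            def worker():
                result = compute()  # hypothetical blocking work
                # add_callback is thread-safe; the wrapped callback
                # restores the captured context when it runs.
                io_loop.add_callback(callback, result)

            threading.Thread(target=worker).start()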
"""
# Check if function is already wrapped
if fn is None or hasattr(fn, '_wrapped'):
return fn
# Capture current stack head
# TODO: Any other better way to store contexts and update them in wrapped function?
cap_contexts = [_state.contexts]
def wrapped(*args, **kwargs):
ret = None
try:
# Capture old state
current_state = _state.contexts
# Remove deactivated items
cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
# Force new state
_state.contexts = contexts
# Current exception
exc = (None, None, None)
top = None
# Apply stack contexts
last_ctx = 0
stack = contexts[0]
# Apply state
for n in stack:
try:
n.enter()
last_ctx += 1
except:
# Exception happened. Record exception info and store top-most handler
exc = sys.exc_info()
top = n.old_contexts[1]
# Execute callback if no exception happened while restoring state
if top is None:
try:
ret = fn(*args, **kwargs)
except:
exc = sys.exc_info()
top = contexts[1]
            # If there was an exception, try to handle it by going
            # through the exception chain
if top is not None:
exc = _handle_exception(top, exc)
else:
# Otherwise take shorter path and run stack contexts in reverse order
while last_ctx > 0:
last_ctx -= 1
c = stack[last_ctx]
try:
c.exit(*exc)
except:
exc = sys.exc_info()
top = c.old_contexts[1]
break
else:
top = None
            # If an exception happened while unrolling, take the longer
            # exception handler path
if top is not None:
exc = _handle_exception(top, exc)
# If exception was not handled, raise it
if exc != (None, None, None):
raise_exc_info(exc)
finally:
_state.contexts = current_state
return ret
wrapped._wrapped = True
return wrapped
def _handle_exception(tail, exc):
while tail is not None:
try:
if tail.exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
tail = tail.old_contexts[1]
return exc
def run_with_stack_context(context, func):
"""Run a coroutine ``func`` in the given `StackContext`.
It is not safe to have a ``yield`` statement within a ``with StackContext``
block, so it is difficult to use stack context with `.gen.coroutine`.
This helper function runs the function in the correct context while
keeping the ``yield`` and ``with`` statements syntactically separate.
Example::
@gen.coroutine
def incorrect():
with StackContext(ctx):
# ERROR: this will raise StackContextInconsistentError
yield other_coroutine()
@gen.coroutine
def correct():
yield run_with_stack_context(StackContext(ctx), other_coroutine)
.. versionadded:: 3.1
"""
with context:
return func()

View file

@ -0,0 +1,244 @@
#!/usr/bin/env python
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded TCP server."""
from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import socket
import ssl
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
from tornado import process
class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
To make this server serve SSL traffic, send the ssl_options dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including "certfile" and "keyfile"::
TCPServer(ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
`TCPServer` initialization follows one of three patterns:
1. `listen`: simple single-process::
server = TCPServer()
server.listen(8888)
IOLoop.instance().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
the server on the default singleton `.IOLoop`.
3. `add_sockets`: advanced multi-process::
sockets = bind_sockets(8888)
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.instance().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
flexibility in when the fork happens. `add_sockets` can
also be used in single-process servers if you want to create
your listening sockets in some way other than
`~tornado.netutil.bind_sockets`.
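
    A minimal echo-server sketch (an editorial illustration; the port
    number is a placeholder)::

        from tornado.ioloop import IOLoop
        from tornado.tcpserver import TCPServer

        class EchoServer(TCPServer):
            def handle_stream(self, stream, address):
                def echo(data):
                    stream.write(data)
                    stream.read_until(b"\n", echo)
                stream.read_until(b"\n", echo)

        server = EchoServer()
        server.listen(8888)
        IOLoop.instance().start()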
.. versionadded:: 3.1
The ``max_buffer_size`` argument.
"""
def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None):
self.io_loop = io_loop
self.ssl_options = ssl_options
self._sockets = {} # fd -> socket object
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
        # the SSL module doesn't do that until there is a connected
        # socket, which seems like too much work.
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
if 'certfile' not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
if not os.path.exists(self.ssl_options['certfile']):
raise ValueError('certfile "%s" does not exist' %
self.ssl_options['certfile'])
if ('keyfile' in self.ssl_options and
not os.path.exists(self.ssl_options['keyfile'])):
raise ValueError('keyfile "%s" does not exist' %
self.ssl_options['keyfile'])
def listen(self, port, address=""):
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
`listen` takes effect immediately; it is not necessary to call
`TCPServer.start` afterwards. It is, however, necessary to start
the `.IOLoop`.
"""
sockets = bind_sockets(port, address=address)
self.add_sockets(sockets)
def add_sockets(self, sockets):
"""Makes this server start accepting connections on the given sockets.
The ``sockets`` parameter is a list of socket objects such as
those returned by `~tornado.netutil.bind_sockets`.
`add_sockets` is typically used in combination with that
method and `tornado.process.fork_processes` to provide greater
control over the initialization of a multi-process server.
"""
if self.io_loop is None:
self.io_loop = IOLoop.current()
for sock in sockets:
self._sockets[sock.fileno()] = sock
add_accept_handler(sock, self._handle_connection,
io_loop=self.io_loop)
def add_socket(self, socket):
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
in a single process, you can call `listen` as a shortcut to the
sequence of `bind` and `start` calls.
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen <socket.socket.listen>`.
This method may be called multiple times prior to `start` to listen
on multiple ports or interfaces.
"""
sockets = bind_sockets(port, address=address, family=family,
backlog=backlog)
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
def start(self, num_processes=1):
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
additional child process.
If num_processes is ``None`` or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If num_processes is given and > 1, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``debug=True`` option to `tornado.web.Application`).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``TCPServer.start(n)``.
"""
assert not self._started
self._started = True
if num_processes != 1:
process.fork_processes(num_processes)
sockets = self._pending_sockets
self._pending_sockets = []
self.add_sockets(sockets)
def stop(self):
"""Stops listening for new connections.
Requests currently in progress may still continue after the
server is stopped.
"""
for fd, sock in self._sockets.items():
self.io_loop.remove_handler(fd)
sock.close()
def handle_stream(self, stream, address):
"""Override to handle a new `.IOStream` from an incoming connection."""
raise NotImplementedError()
def _handle_connection(self, connection, address):
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
connection = ssl_wrap_socket(connection,
self.ssl_options,
server_side=True,
do_handshake_on_connect=False)
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
else:
raise
except socket.error as err:
# If the connection is closed immediately after it is created
# (as in a port scan), we can get one of several errors.
# wrap_socket makes an internal call to getpeername,
# which may return either EINVAL (Mac OS X) or ENOTCONN
# (Linux). If it returns ENOTCONN, this error is
# silently swallowed by the ssl module, so we need to
# catch another error later on (AttributeError in
# SSLIOStream._do_ssl_handshake).
# To test this behavior, try nmap with the -sT flag.
# https://github.com/facebook/tornado/pull/750
if err.args[0] in (errno.ECONNABORTED, errno.EINVAL):
return connection.close()
else:
raise
try:
if self.ssl_options is not None:
stream = SSLIOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size)
else:
stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size)
self.handle_stream(stream, address)
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View file

@ -0,0 +1,861 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A simple template system that compiles templates to Python code.
Basic usage looks like::
t = template.Template("<html>{{ myvalue }}</html>")
print t.generate(myvalue="XXX")
Loader is a class that loads templates from a root directory and caches
the compiled templates::
loader = template.Loader("/home/btaylor")
print loader.load("test.html").generate(myvalue="XXX")
We compile all templates to raw Python. Error-reporting is currently... uh,
interesting. Syntax for the templates::
### base.html
<html>
<head>
<title>{% block title %}Default title{% end %}</title>
</head>
<body>
<ul>
{% for student in students %}
{% block student %}
<li>{{ escape(student.name) }}</li>
{% end %}
{% end %}
</ul>
</body>
</html>
### bold.html
{% extends "base.html" %}
{% block title %}A bolder title{% end %}
{% block student %}
<li><span style="bold">{{ escape(student.name) }}</span></li>
{% end %}
Unlike most other template systems, we do not put any restrictions on the
expressions you can include in your statements. if and for blocks get
translated exactly into Python, so you can do complex expressions like::
{% for student in [p for p in people if p.student and p.age > 23] %}
<li>{{ escape(student.name) }}</li>
{% end %}
Translating directly to Python means you can apply functions to expressions
easily, like the escape() function in the examples above. You can pass
functions in to your template just like any other variable::
### Python code
def add(x, y):
return x + y
template.execute(add=add)
### The template
{{ add(1, 2) }}
We provide the functions escape(), url_escape(), json_encode(), and squeeze()
to all templates by default.
Typical applications do not create `Template` or `Loader` instances by
hand, but instead use the `~.RequestHandler.render` and
`~.RequestHandler.render_string` methods of
`tornado.web.RequestHandler`, which load templates automatically based
on the ``template_path`` `.Application` setting.
Variable names beginning with ``_tt_`` are reserved by the template
system and should not be used by application code.
Syntax Reference
----------------
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
The contents may be any python expression, which will be escaped according
to the current autoescape setting and inserted into the output. Other
template directives use ``{% %}``. These tags may be escaped as ``{{!``
and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.
``{% apply *function* %}...{% end %}``
Applies a function to the output of all template code between ``apply``
and ``end``::
{% apply linkify %}{{name}} said: {{message}}{% end %}
Note that as an implementation detail apply blocks are implemented
as nested functions and thus may interact strangely with variables
set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}``
within loops.
``{% autoescape *function* %}``
Sets the autoescape mode for the current file. This does not affect
other files, even those referenced by ``{% include %}``. Note that
autoescaping can also be configured globally, at the `.Application`
or `Loader`::
{% autoescape xhtml_escape %}
{% autoescape None %}
``{% block *name* %}...{% end %}``
Indicates a named, replaceable block for use with ``{% extends %}``.
Blocks in the parent template will be replaced with the contents of
the same-named block in a child template::
<!-- base.html -->
<title>{% block title %}Default title{% end %}</title>
<!-- mypage.html -->
{% extends "base.html" %}
{% block title %}My page title{% end %}
``{% comment ... %}``
A comment which will be removed from the template output. Note that
there is no ``{% end %}`` tag; the comment goes from the word ``comment``
to the closing ``%}`` tag.
``{% extends *filename* %}``
Inherit from another template. Templates that use ``extends`` should
contain one or more ``block`` tags to replace content from the parent
template. Anything in the child template not contained in a ``block``
tag will be ignored. For an example, see the ``{% block %}`` tag.
``{% for *var* in *expr* %}...{% end %}``
Same as the python ``for`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% from *x* import *y* %}``
Same as the python ``import`` statement.
``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}``
Conditional statement - outputs the first section whose condition is
true. (The ``elif`` and ``else`` sections are optional)
``{% import *module* %}``
Same as the python ``import`` statement.
``{% include *filename* %}``
Includes another template file. The included file can see all the local
variables as if it were copied directly to the point of the ``include``
directive (the ``{% autoescape %}`` directive is an exception).
Alternately, ``{% module Template(filename, **kwargs) %}`` may be used
to include another template with an isolated namespace.
``{% module *expr* %}``
Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is
not escaped::
{% module Template("foo.html", arg=42) %}
``{% raw *expr* %}``
Outputs the result of the given expression without autoescaping.
``{% set *x* = *y* %}``
Sets a local variable.
``{% try %}...{% except %}...{% finally %}...{% else %}...{% end %}``
Same as the python ``try`` statement.
``{% while *condition* %}... {% end %}``
Same as the python ``while`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import linecache
import os.path
import posixpath
import re
import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import bytes_type, ObjectDict, exec_in, unicode_type
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
class Template(object):
"""A compiled template.
We compile into Python from the given template_string. You can generate
the template from variables with generate().
"""
# note that the constructor's signature is not extracted with
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
def __init__(self, template_string, name="<string>", loader=None,
compress_whitespace=None, autoescape=_UNSET):
self.name = name
if compress_whitespace is None:
compress_whitespace = name.endswith(".html") or \
name.endswith(".js")
if autoescape is not _UNSET:
self.autoescape = autoescape
elif loader:
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
reader = _TemplateReader(name, escape.native_str(template_string))
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader, compress_whitespace)
self.loader = loader
try:
# Under python2.5, the fake filename used here must match
# the module name used in __name__ below.
# The dont_inherit flag prevents template.py's future imports
# from being applied to the generated code.
self.compiled = compile(
escape.to_unicode(self.code),
"%s.generated.py" % self.name.replace('.', '_'),
"exec", dont_inherit=True)
except Exception:
formatted_code = _format_code(self.code).rstrip()
app_log.error("%s code:\n%s", self.name, formatted_code)
raise
def generate(self, **kwargs):
"""Generate this template with the given arguments."""
namespace = {
"escape": escape.xhtml_escape,
"xhtml_escape": escape.xhtml_escape,
"url_escape": escape.url_escape,
"json_encode": escape.json_encode,
"squeeze": escape.squeeze,
"linkify": escape.linkify,
"datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes_type),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
"__name__": self.name.replace('.', '_'),
"__loader__": ObjectDict(get_source=lambda name: self.code),
}
namespace.update(self.namespace)
namespace.update(kwargs)
exec_in(self.compiled, namespace)
execute = namespace["_tt_execute"]
# Clear the traceback module's cache of source data now that
# we've generated a new template (mainly for this module's
# unittests, where different tests reuse the same name).
linecache.clearcache()
return execute()
def _generate_python(self, loader, compress_whitespace):
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
named_blocks = {}
ancestors = self._get_ancestors(loader)
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template,
compress_whitespace)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
buffer.close()
def _get_ancestors(self, loader):
ancestors = [self.file]
for chunk in self.file.body.chunks:
if isinstance(chunk, _ExtendsBlock):
if not loader:
raise ParseError("{% extends %} block found, but no "
"template loader")
template = loader.load(chunk.name, self.name)
ancestors.extend(template._get_ancestors(loader))
return ancestors
class BaseLoader(object):
"""Base class for template loaders.
You must use a template loader to use template constructs like
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None):
"""``autoescape`` must be either None or a string naming a function
in the template namespace, such as "xhtml_escape".
"""
self.autoescape = autoescape
self.namespace = namespace or {}
self.templates = {}
# self.lock protects self.templates. It's a reentrant lock
# because templates may load other templates via `include` or
# `extends`. Note that thanks to the GIL this code would be safe
# even without the lock, but could lead to wasted work as multiple
# threads tried to compile the same template simultaneously.
self.lock = threading.RLock()
def reset(self):
"""Resets the cache of compiled templates."""
with self.lock:
self.templates = {}
def resolve_path(self, name, parent_path=None):
"""Converts a possibly-relative path to absolute (used internally)."""
raise NotImplementedError()
def load(self, name, parent_path=None):
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
with self.lock:
if name not in self.templates:
self.templates[name] = self._create_template(name)
return self.templates[name]
def _create_template(self, name):
raise NotImplementedError()
class Loader(BaseLoader):
"""A template loader that loads from a single root directory.
"""
def __init__(self, root_directory, **kwargs):
super(Loader, self).__init__(**kwargs)
self.root = os.path.abspath(root_directory)
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
current_path = os.path.join(self.root, parent_path)
file_dir = os.path.dirname(os.path.abspath(current_path))
relative_path = os.path.abspath(os.path.join(file_dir, name))
if relative_path.startswith(self.root):
name = relative_path[len(self.root) + 1:]
return name
def _create_template(self, name):
path = os.path.join(self.root, name)
f = open(path, "rb")
template = Template(f.read(), name=name, loader=self)
f.close()
return template
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
def __init__(self, dict, **kwargs):
super(DictLoader, self).__init__(**kwargs)
self.dict = dict
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
file_dir = posixpath.dirname(parent_path)
name = posixpath.normpath(posixpath.join(file_dir, name))
return name
def _create_template(self, name):
return Template(self.dict[name], name=name, loader=self)
class _Node(object):
def each_child(self):
return ()
def generate(self, writer):
raise NotImplementedError()
def find_named_blocks(self, loader, named_blocks):
for child in self.each_child():
child.find_named_blocks(loader, named_blocks)
class _File(_Node):
def __init__(self, template, body):
self.template = template
self.body = body
self.line = 0
def generate(self, writer):
writer.write_line("def _tt_execute():", self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
def each_child(self):
return (self.body,)
class _ChunkList(_Node):
def __init__(self, chunks):
self.chunks = chunks
def generate(self, writer):
for chunk in self.chunks:
chunk.generate(writer)
def each_child(self):
return self.chunks
class _NamedBlock(_Node):
def __init__(self, name, body, template, line):
self.name = name
self.body = body
self.template = template
self.line = line
def each_child(self):
return (self.body,)
def generate(self, writer):
block = writer.named_blocks[self.name]
with writer.include(block.template, self.line):
block.body.generate(writer)
def find_named_blocks(self, loader, named_blocks):
named_blocks[self.name] = self
_Node.find_named_blocks(self, loader, named_blocks)
class _ExtendsBlock(_Node):
def __init__(self, name):
self.name = name
class _IncludeBlock(_Node):
def __init__(self, name, reader, line):
self.name = name
self.template_name = reader.name
self.line = line
def find_named_blocks(self, loader, named_blocks):
included = loader.load(self.name, self.template_name)
included.file.find_named_blocks(loader, named_blocks)
def generate(self, writer):
included = writer.loader.load(self.name, self.template_name)
with writer.include(included, self.line):
included.file.body.generate(writer)
class _ApplyBlock(_Node):
def __init__(self, method, line, body=None):
self.method = method
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
method_name = "_tt_apply%d" % writer.apply_counter
writer.apply_counter += 1
writer.write_line("def %s():" % method_name, self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
self.method, method_name), self.line)
class _ControlBlock(_Node):
def __init__(self, statement, line, body=None):
self.statement = statement
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
# Just in case the body was empty
writer.write_line("pass", self.line)
class _IntermediateControlBlock(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
# In case the previous block was empty
writer.write_line("pass", self.line)
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
class _Statement(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
writer.write_line(self.statement, self.line)
class _Expression(_Node):
def __init__(self, expression, line, raw=False):
self.expression = expression
self.line = line
self.raw = raw
def generate(self, writer):
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
" _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
writer.current_template.autoescape, self.line)
writer.write_line("_tt_append(_tt_tmp)", self.line)
class _Module(_Expression):
def __init__(self, expression, line):
super(_Module, self).__init__("_tt_modules." + expression, line,
raw=True)
class _Text(_Node):
def __init__(self, value, line):
self.value = value
self.line = line
def generate(self, writer):
value = self.value
# Compress lots of white space to a single character. If the whitespace
# breaks a line, have it continue to break a line, but just with a
# single \n character
if writer.compress_whitespace and "<pre>" not in value:
value = re.sub(r"([\t ]+)", " ", value)
value = re.sub(r"(\s*\n\s*)", "\n", value)
if value:
writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
class ParseError(Exception):
"""Raised for template syntax errors."""
pass
class _CodeWriter(object):
def __init__(self, file, named_blocks, loader, current_template,
compress_whitespace):
self.file = file
self.named_blocks = named_blocks
self.loader = loader
self.current_template = current_template
self.compress_whitespace = compress_whitespace
self.apply_counter = 0
self.include_stack = []
self._indent = 0
def indent_size(self):
return self._indent
def indent(self):
class Indenter(object):
def __enter__(_):
self._indent += 1
return self
def __exit__(_, *args):
assert self._indent > 0
self._indent -= 1
return Indenter()
def include(self, template, line):
self.include_stack.append((self.current_template, line))
self.current_template = template
class IncludeTemplate(object):
def __enter__(_):
return self
def __exit__(_, *args):
self.current_template = self.include_stack.pop()[0]
return IncludeTemplate()
def write_line(self, line, line_number, indent=None):
if indent is None:
indent = self._indent
line_comment = ' # %s:%d' % (self.current_template.name, line_number)
if self.include_stack:
ancestors = ["%s:%d" % (tmpl.name, lineno)
for (tmpl, lineno) in self.include_stack]
line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
print(" " * indent + line + line_comment, file=self.file)
class _TemplateReader(object):
def __init__(self, name, text):
self.name = name
self.text = text
self.line = 1
self.pos = 0
def find(self, needle, start=0, end=None):
assert start >= 0, start
pos = self.pos
start += pos
if end is None:
index = self.text.find(needle, start)
else:
end += pos
assert end >= start
index = self.text.find(needle, start, end)
if index != -1:
index -= pos
return index
def consume(self, count=None):
if count is None:
count = len(self.text) - self.pos
newpos = self.pos + count
self.line += self.text.count("\n", self.pos, newpos)
s = self.text[self.pos:newpos]
self.pos = newpos
return s
def remaining(self):
return len(self.text) - self.pos
def __len__(self):
return self.remaining()
def __getitem__(self, key):
if type(key) is slice:
size = len(self)
# slice.indices() always returns concrete integers clamped to size,
# so the position offset can be applied directly.
start, stop, step = key.indices(size)
return self.text[slice(start + self.pos, stop + self.pos, step)]
elif key < 0:
return self.text[key]
else:
return self.text[self.pos + key]
def __str__(self):
return self.text[self.pos:]
def _format_code(code):
lines = code.splitlines()
# Pad line numbers to the width of the largest; "fmt" avoids shadowing
# the built-in format().
fmt = "%%%dd %%s\n" % len(repr(len(lines) + 1))
return "".join([fmt % (i + 1, line) for (i, line) in enumerate(lines)])
def _parse(reader, template, in_block=None, in_loop=None):
body = _ChunkList([])
while True:
# Find next template directive
curly = 0
while True:
curly = reader.find("{", curly)
if curly == -1 or curly + 1 == reader.remaining():
# EOF
if in_block:
raise ParseError("Missing {%% end %%} block for %s" %
in_block)
body.chunks.append(_Text(reader.consume(), reader.line))
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
if reader[curly + 1] not in ("{", "%", "#"):
curly += 1
continue
# When there are more than 2 curlies in a row, use the
# innermost ones. This is useful when generating languages
# like latex where curlies are also meaningful
if (curly + 2 < reader.remaining() and
reader[curly + 1] == '{' and reader[curly + 2] == '{'):
curly += 1
continue
break
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
body.chunks.append(_Text(cons, reader.line))
start_brace = reader.consume(2)
line = reader.line
# Template directives may be escaped as "{{!" or "{%!".
# In this case output the braces and consume the "!".
# This is especially useful in conjunction with jquery templates,
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
body.chunks.append(_Text(start_brace, line))
continue
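# Illustrative input/output for the escape handled above:
#     "{{! expr }}" renders literally as "{{ expr }}"
#     "{%! if x %}" renders literally as "{% if x %}"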
# Comment
if start_brace == "{#":
end = reader.find("#}")
if end == -1:
raise ParseError("Missing end expression #} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
continue
# Expression
if start_brace == "{{":
end = reader.find("}}")
if end == -1:
raise ParseError("Missing end expression }} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty expression on line %d" % line)
body.chunks.append(_Expression(contents, line))
continue
# Block
assert start_brace == "{%", start_brace
end = reader.find("%}")
if end == -1:
raise ParseError("Missing end block %%} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty block tag ({%% %%}) on line %d" % line)
operator, space, suffix = contents.partition(" ")
suffix = suffix.strip()
# Intermediate ("else", "elif", etc) blocks
intermediate_blocks = {
"else": set(["if", "for", "while", "try"]),
"elif": set(["if"]),
"except": set(["try"]),
"finally": set(["try"]),
}
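# e.g. "{% try %} ... {% except %} ... {% finally %} ... {% end %}" is
# parsed as a single try block with two intermediate blocks attached.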
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
raise ParseError("%s outside %s block" %
(operator, allowed_parents))
if in_block not in allowed_parents:
raise ParseError("%s block cannot be attached to %s block" % (operator, in_block))
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
# End tag
elif operator == "end":
if not in_block:
raise ParseError("Extra {%% end %%} block on line %d" % line)
return body
elif operator in ("extends", "include", "set", "import", "from",
"comment", "autoescape", "raw", "module"):
if operator == "comment":
continue
if operator == "extends":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("extends missing file path on line %d" % line)
block = _ExtendsBlock(suffix)
elif operator in ("import", "from"):
if not suffix:
raise ParseError("import missing statement on line %d" % line)
block = _Statement(contents, line)
elif operator == "include":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("include missing file path on line %d" % line)
block = _IncludeBlock(suffix, reader, line)
elif operator == "set":
if not suffix:
raise ParseError("set missing statement on line %d" % line)
block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
if fn == "None":
fn = None
template.autoescape = fn
continue
elif operator == "raw":
block = _Expression(suffix, line, raw=True)
elif operator == "module":
block = _Module(suffix, line)
body.chunks.append(block)
continue
elif operator in ("apply", "block", "try", "if", "for", "while"):
# parse inner body recursively
if operator in ("for", "while"):
block_body = _parse(reader, template, operator, operator)
elif operator == "apply":
# apply creates a nested function so syntactically it's not
# in the loop.
block_body = _parse(reader, template, operator, None)
else:
block_body = _parse(reader, template, operator, in_loop)
if operator == "apply":
if not suffix:
raise ParseError("apply missing method name on line %d" % line)
block = _ApplyBlock(suffix, line, block_body)
elif operator == "block":
if not suffix:
raise ParseError("block missing name on line %d" % line)
block = _NamedBlock(suffix, block_body, template, line)
else:
block = _ControlBlock(contents, line, block_body)
body.chunks.append(block)
continue
elif operator in ("break", "continue"):
if not in_loop:
raise ParseError("%s outside %s block" % (operator, set(["for", "while"])))
body.chunks.append(_Statement(contents, line))
continue
else:
raise ParseError("unknown operator: %r" % operator)

View file

@@ -0,0 +1,4 @@
Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH appropriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect.
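
One plausible invocation (hypothetical path; adjust for your checkout)
would be:

    PYTHONPATH=/path/to/tornado python -m tornado.test.runtests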

View file

@@ -0,0 +1,424 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
from __future__ import absolute_import, division, print_function, with_statement
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.util import u
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument('openid.mode', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
if self.get_argument('openid.mode') != 'check_authentication':
raise Exception("incorrect openid.mode %r")
self.write('is_valid:true')
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authorize_redirect(http_client=self.settings['http_client'])
assert isinstance(res, Future)
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def _oauth_get_user(self, access_token, callback):
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
callback(dict(email='foo@example.com'))
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
def get(self):
params = self._oauth_request_parameters(
'http://www.example.com/api/asdf',
dict(key='uiop', secret='5678'),
parameters=dict(foo='bar'))
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=zxcv&oauth_token_secret=1234')
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=uiop&oauth_token_secret=5678')
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
def get(self):
res = self.authorize_redirect()
assert isinstance(res, Future)
assert res.done()
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
def get_auth_http_client(self):
return self.settings['http_client']
class TwitterClientLoginHandler(TwitterClientHandler):
@asynchronous
def get(self):
if self.get_argument("oauth_token", None):
self.get_authenticated_user(self.on_user)
return
self.authorize_redirect()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# Old style: with @gen.engine we can ignore the Future from
# authorize_redirect.
self.authorize_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# New style: with @gen.coroutine the result must be yielded
# or else the request will be auto-finished too soon.
yield self.authorize_redirect()
class TwitterClientShowUserHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
response = yield gen.Task(self.twitter_request,
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
if response is None:
self.set_status(500)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterClientShowUserFutureHandler(TwitterClientHandler):
@asynchronous
@gen.engine
def get(self):
try:
response = yield self.twitter_request(
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
except AuthError as e:
self.set_status(500)
self.finish(str(e))
return
assert response is not None
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
if screen_name == 'error':
raise HTTPError(500)
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name='foo', name='Foo'))
class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument("openid.mode", None):
self.get_authenticated_user(self.on_user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def get_auth_http_client(self):
return self.settings['http_client']
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
('/oauth10/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0')),
('/oauth10/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
('/twitter/client/login_gen_engine', TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
(r'/twitter/api/account/verify_credentials\.json', TwitterServerVerifyCredentialsHandler),
],
http_client=self.http_client,
twitter_consumer_key='test_twitter_consumer_key',
twitter_consumer_secret='test_twitter_consumer_secret')
def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user(self):
response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect(self):
response = self.fetch('/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10_get_user(self):
response = self.fetch(
'/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10_request_parameters(self):
response = self.fetch('/oauth10/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_redirect(self):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10a_get_user(self):
response = self.fetch(
'/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10a_request_parameters(self):
response = self.fetch('/oauth10a/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_redirect(self):
self.base_twitter_redirect('/twitter/client/login')
def test_twitter_redirect_gen_engine(self):
self.base_twitter_redirect('/twitter/client/login_gen_engine')
def test_twitter_redirect_gen_coroutine(self):
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
def test_twitter_get_user(self):
response = self.fetch(
'/twitter/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed,
{u('access_token'): {u('key'): u('hjkl'),
u('screen_name'): u('foo'),
u('secret'): u('vbnm')},
u('name'): u('Foo'),
u('screen_name'): u('foo'),
u('username'): u('foo')})
def test_twitter_show_user(self):
response = self.fetch('/twitter/client/show_user?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error(self):
with ExpectLog(gen_log, 'Error response HTTP 500'):
response = self.fetch('/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
def test_twitter_show_user_future(self):
response = self.fetch('/twitter/client/show_user_future?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_future_error(self):
response = self.fetch('/twitter/client/show_user_future?name=error')
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
def test_google_redirect(self):
# same as test_openid_redirect
response = self.fetch('/google/client/openid_login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_google_get_user(self):
response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")

View file

@@ -0,0 +1,330 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
import logging
import re
import socket
import sys
import traceback
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
class ReturnFutureTest(AsyncTestCase):
@return_future
def sync_future(self, callback):
callback(42)
@return_future
def async_future(self, callback):
self.io_loop.add_callback(callback, 42)
@return_future
def immediate_failure(self, callback):
1 / 0
@return_future
def delayed_failure(self, callback):
self.io_loop.add_callback(lambda: 1 / 0)
@return_future
def return_value(self, callback):
# Note that the result of both running the callback and returning
# a value (or raising an exception) is unspecified; with current
# implementations the last event prior to callback resolution wins.
return 42
@return_future
def no_result_future(self, callback):
callback()
def test_immediate_failure(self):
with self.assertRaises(ZeroDivisionError):
# The caller sees the error just like a normal function.
self.immediate_failure(callback=self.stop)
# The callback is not run because the function failed synchronously.
self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
result = self.wait()
self.assertIs(result, None)
def test_return_value(self):
with self.assertRaises(ReturnValueIgnoredError):
self.return_value(callback=self.stop)
def test_callback_kw(self):
future = self.sync_future(callback=self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_callback_positional(self):
# When the callback is passed in positionally, return_future shouldn't
# add another callback in the kwargs.
future = self.sync_future(self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_no_callback(self):
future = self.sync_future()
self.assertEqual(future.result(), 42)
def test_none_callback_kw(self):
# explicitly pass None as callback
future = self.sync_future(callback=None)
self.assertEqual(future.result(), 42)
def test_none_callback_pos(self):
future = self.sync_future(None)
self.assertEqual(future.result(), 42)
def test_async_future(self):
future = self.async_future()
self.assertFalse(future.done())
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
self.assertEqual(future.result(), 42)
@gen_test
def test_async_future_gen(self):
result = yield self.async_future()
self.assertEqual(result, 42)
def test_delayed_failure(self):
future = self.delayed_failure()
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
with self.assertRaises(ZeroDivisionError):
future.result()
def test_kw_only_callback(self):
@return_future
def f(**kwargs):
kwargs['callback'](42)
future = f()
self.assertEqual(future.result(), 42)
def test_error_in_callback(self):
self.sync_future(callback=lambda future: 1 / 0)
# The exception gets caught by our StackContext and will be re-raised
# when we wait.
self.assertRaises(ZeroDivisionError, self.wait)
def test_no_result_future(self):
future = self.no_result_future(self.stop)
result = self.wait()
self.assertIs(result, None)
# result of this future is undefined, but not an error
future.result()
def test_no_result_future_callback(self):
future = self.no_result_future(callback=lambda: self.stop())
result = self.wait()
self.assertIs(result, None)
future.result()
@gen_test
def test_future_traceback(self):
@return_future
@gen.engine
def f(callback):
yield gen.Task(self.io_loop.add_callback)
try:
1 / 0
except ZeroDivisionError:
self.expected_frame = traceback.extract_tb(
sys.exc_info()[2], limit=1)[0]
raise
try:
yield f()
self.fail("didn't get expected exception")
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
def handle_stream(self, stream, address):
logging.info("handle_stream")
self.stream = stream
self.stream.read_until(b"\n", self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
data = to_unicode(data)
if data == data.upper():
self.stream.write(b"error\talready capitalized\n")
else:
# data already has \n
self.stream.write(utf8("ok\t%s" % data.upper()))
self.stream.close()
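# Wire protocol used by the clients below (illustrative): the client
# sends "hello\n" and the server replies "ok\tHELLO\n", or
# "error\talready capitalized\n" if the input was already upper-case.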
class CapError(Exception):
pass
class BaseCapClient(object):
def __init__(self, port, io_loop):
self.port = port
self.io_loop = io_loop
def process_response(self, data):
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
if status == 'ok':
return message
else:
raise CapError(message)
class ManualCapClient(BaseCapClient):
def capitalize(self, request_data, callback=None):
logging.info("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.future = Future()
if callback is not None:
self.future.add_done_callback(
stack_context.wrap(lambda future: callback(future.result())))
return self.future
def handle_connect(self):
logging.info("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
self.stream.close()
try:
self.future.set_result(self.process_response(data))
except CapError as e:
self.future.set_exception(e)
class DecoratorCapClient(BaseCapClient):
@return_future
def capitalize(self, request_data, callback):
logging.info("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.callback = callback
def handle_connect(self):
logging.info("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.info("handle_read")
self.stream.close()
self.callback(self.process_response(data))
class GeneratorCapClient(BaseCapClient):
@return_future
@gen.engine
def capitalize(self, request_data, callback):
logging.info('capitalize')
stream = IOStream(socket.socket(), io_loop=self.io_loop)
logging.info('connecting')
yield gen.Task(stream.connect, ('127.0.0.1', self.port))
stream.write(utf8(request_data + '\n'))
logging.info('reading')
data = yield gen.Task(stream.read_until, b'\n')
logging.info('returning')
stream.close()
callback(self.process_response(data))
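# All three clients expose the same capitalize() interface:
# ManualCapClient constructs and resolves its Future by hand,
# DecoratorCapClient lets @return_future create the Future from its
# callback argument, and GeneratorCapClient layers @gen.engine under
# @return_future so the body reads sequentially.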
class ClientTestMixin(object):
def setUp(self):
super(ClientTestMixin, self).setUp()
self.server = CapServer(io_loop=self.io_loop)
sock, port = bind_unused_port()
self.server.add_sockets([sock])
self.client = self.client_class(io_loop=self.io_loop, port=port)
def tearDown(self):
self.server.stop()
super(ClientTestMixin, self).tearDown()
def test_callback(self):
self.client.capitalize("hello", callback=self.stop)
result = self.wait()
self.assertEqual(result, "HELLO")
def test_callback_error(self):
self.client.capitalize("HELLO", callback=self.stop)
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
def test_future(self):
future = self.client.capitalize("hello")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertEqual(future.result(), "HELLO")
def test_future_error(self):
future = self.client.capitalize("HELLO")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
def test_generator(self):
@gen.engine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
self.stop()
f()
self.wait()
def test_generator_error(self):
@gen.engine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
self.stop()
f()
self.wait()
class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = ManualCapClient
class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = DecoratorCapClient
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = GeneratorCapClient

View file

@@ -0,0 +1 @@
"school","école"

View file

@@ -0,0 +1,99 @@
from __future__ import absolute_import, division, print_function, with_statement
from hashlib import md5
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
try:
import pycurl
except ImportError:
pycurl = None
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(io_loop=self.io_loop)
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
class DigestAuthHandler(RequestHandler):
def get(self):
realm = 'test'
opaque = 'asdf'
# Real implementations would use a random nonce.
nonce = "1234"
username = 'foo'
password = 'bar'
auth_header = self.request.headers.get('Authorization', None)
if auth_header is not None:
auth_mode, params = auth_header.split(' ', 1)
assert auth_mode == 'Digest'
param_dict = {}
for pair in params.split(','):
k, v = pair.strip().split('=', 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
assert param_dict['realm'] == realm
assert param_dict['opaque'] == opaque
assert param_dict['nonce'] == nonce
assert param_dict['username'] == username
assert param_dict['uri'] == self.request.path
# Hash inputs as bytes so the digest also works under Python 3,
# where hashlib rejects unicode strings.
h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
h2 = md5(utf8('%s:%s' % (self.request.method,
self.request.path))).hexdigest()
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
if digest == param_dict['response']:
self.write('ok')
else:
self.write('fail')
else:
self.set_status(401)
self.set_header('WWW-Authenticate',
'Digest realm="%s", nonce="%s", opaque="%s"' %
(realm, nonce, opaque))
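# For reference, the RFC 2617 digest computed above is:
#     HA1      = MD5(username ":" realm ":" password)
#     HA2      = MD5(method ":" uri)
#     response = MD5(HA1 ":" nonce ":" HA2)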
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = CurlAsyncHTTPClient(self.io_loop)
def get_app(self):
return Application([
('/digest', DigestAuthHandler),
])
def test_prepare_curl_callback_stack_context(self):
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
self.stop()
return True
with ExceptionStackContext(error_handler):
request = HTTPRequest(self.get_url('/'),
prepare_curl_callback=lambda curl: 1 / 0)
self.http_client.fetch(request, callback=self.stop)
self.wait()
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_digest_auth(self):
response = self.fetch('/digest', auth_mode='digest',
auth_username='foo', auth_password='bar')
self.assertEqual(response.body, b'ok')

View file

@@ -0,0 +1,217 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode
from tornado.util import u, unicode_type, bytes_type
from tornado.test.util import unittest
linkify_tests = [
# (input, linkify_kwargs, expected_output)
("hello http://world.com/!", {},
u('hello <a href="http://world.com/">http://world.com/</a>!')),
("hello http://world.com/with?param=true&stuff=yes", {},
u('hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>')),
# an opened paren followed by many chars killed Gruber's regex
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
u('<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa')),
# as did too many dots at the end
("http://url.com/withmany.......................................", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................')),
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
u('<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)')),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a few extras (such as multiple parentheses).
("http://foo.com/blah_blah", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>')),
("http://foo.com/blah_blah/", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>')),
("(Something like http://foo.com/blah_blah)", {},
u('(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)')),
("http://foo.com/blah_blah_(wikipedia)", {},
u('<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>')),
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
u('<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>')),
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
u('(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)')),
("http://foo.com/blah_blah.", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.')),
("http://foo.com/blah_blah/.", {},
u('<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.')),
("<http://foo.com/blah_blah>", {},
u('&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;')),
("<http://foo.com/blah_blah/>", {},
u('&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;')),
("http://foo.com/blah_blah,", {},
u('<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,')),
("http://www.example.com/wpstyle/?p=364.", {},
u('<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.')),
("rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u('<a href="rdar://1234">rdar://1234</a>')),
("rdar:/1234",
{"permitted_protocols": ["rdar"]},
u('<a href="rdar:/1234">rdar:/1234</a>')),
("http://userid:password@example.com:8080", {},
u('<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>')),
("http://userid@example.com", {},
u('<a href="http://userid@example.com">http://userid@example.com</a>')),
("http://userid@example.com:8080", {},
u('<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>')),
("http://userid:password@example.com", {},
u('<a href="http://userid:password@example.com">http://userid:password@example.com</a>')),
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u('<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>')),
(u("http://\u27a1.ws/\u4a39"), {},
u('<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>')),
("<tag>http://example.com</tag>", {},
u('&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;')),
("Just a www.example.com link.", {},
u('Just a <a href="http://www.example.com">www.example.com</a> link.')),
("Just a www.example.com link.",
{"require_protocol": True},
u('Just a www.example.com link.')),
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u('A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html" title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>')),
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u('A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi" title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!')),
("A file:///passwords.txt and http://web.com link", {},
u('A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link')),
("A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u('A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link')),
("www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
("www.external-link.com and www.internal-link.com/blogs extra",
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a> and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra')),
("www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u('<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>')),
]
class EscapeTestCase(unittest.TestCase):
def test_linkify(self):
for text, kwargs, html in linkify_tests:
linked = tornado.escape.linkify(text, **kwargs)
self.assertEqual(linked, html)
def test_xhtml_escape(self):
tests = [
("<foo>", "&lt;foo&gt;"),
(u("<foo>"), u("&lt;foo&gt;")),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"", "&lt;&gt;&amp;&quot;"),
("&amp;", "&amp;amp;"),
(u("<\u00e9>"), u("&lt;\u00e9&gt;")),
(b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
]
for unescaped, escaped in tests:
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
(u('\u00e9').encode('utf8'), '%C3%A9'),
(u('\u00e9').encode('latin1'), '%E9'),
# unicode strings become utf8
(u('\u00e9'), '%C3%A9'),
]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
('%C3%A9', u('\u00e9'), 'utf8'),
('%C3%A9', u('\u00c3\u00a9'), 'latin1'),
('%C3%A9', utf8(u('\u00e9')), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
# characters, but make sure the function accepts both byte
# and unicode strings.
self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
unescaped = '+ #%'
plus_escaped = '%2B+%23%25'
escaped = '%2B%20%23%25'
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
self.assertEqual(url_unescape(plus_escaped, encoding=None),
utf8(unescaped))
self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
utf8(unescaped))
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
# type as their argument
self.assertEqual(type(xhtml_escape("foo")), str)
self.assertEqual(type(xhtml_escape(u("foo"))), unicode_type)
def test_json_decode(self):
# json_decode accepts both bytes and unicode, but strings it returns
# are always unicode.
self.assertEqual(json_decode(b'"foo"'), u("foo"))
self.assertEqual(json_decode(u('"foo"')), u("foo"))
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u('"\u00e9"'))), u("\u00e9"))
def test_json_encode(self):
# json deals with strings, not bytes. On python 2 byte strings will
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
if bytes_type is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")

View file

@@ -0,0 +1,913 @@
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import functools
import sys
import textwrap
import time
import platform
import weakref
from tornado.concurrent import return_future
from tornado.escape import url_escape
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.web import Application, RequestHandler, asynchronous, HTTPError
from tornado import gen
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
'Not CPython implementation')
class GenEngineTest(AsyncTestCase):
def setUp(self):
super(GenEngineTest, self).setUp()
self.named_contexts = []
def named_context(self, name):
@contextlib.contextmanager
def context():
self.named_contexts.append(name)
try:
yield
finally:
self.assertEqual(self.named_contexts.pop(), name)
return context
def run_gen(self, f):
f()
return self.wait()
def delay_callback(self, iterations, callback, arg):
"""Runs callback(arg) after a number of IOLoop iterations."""
if iterations == 0:
callback(arg)
else:
self.io_loop.add_callback(functools.partial(
self.delay_callback, iterations - 1, callback, arg))
@return_future
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
def test_no_yield(self):
@gen.engine
def f():
self.stop()
self.run_gen(f)
def test_inline_cb(self):
@gen.engine
def f():
(yield gen.Callback("k1"))()
res = yield gen.Wait("k1")
self.assertTrue(res is None)
self.stop()
self.run_gen(f)
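# The gen.Callback/gen.Wait pairing exercised here (tornado 3.x
# YieldPoints): yielding gen.Callback("k") produces a callback, and
# yielding gen.Wait("k") suspends until that callback has been invoked,
# evaluating to the argument it was called with.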
def test_ioloop_cb(self):
@gen.engine
def f():
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.stop()
self.run_gen(f)
def test_exception_phase1(self):
@gen.engine
def f():
1 / 0
self.assertRaises(ZeroDivisionError, self.run_gen, f)
def test_exception_phase2(self):
@gen.engine
def f():
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
1 / 0
self.assertRaises(ZeroDivisionError, self.run_gen, f)
def test_exception_in_task_phase1(self):
def fail_task(callback):
1 / 0
@gen.engine
def f():
try:
yield gen.Task(fail_task)
raise Exception("did not get expected exception")
except ZeroDivisionError:
self.stop()
self.run_gen(f)
def test_exception_in_task_phase2(self):
# This is the case that requires the use of stack_context in gen.engine
def fail_task(callback):
self.io_loop.add_callback(lambda: 1 / 0)
@gen.engine
def f():
try:
yield gen.Task(fail_task)
raise Exception("did not get expected exception")
except ZeroDivisionError:
self.stop()
self.run_gen(f)
def test_with_arg(self):
@gen.engine
def f():
(yield gen.Callback("k1"))(42)
res = yield gen.Wait("k1")
self.assertEqual(42, res)
self.stop()
self.run_gen(f)
def test_with_arg_tuple(self):
@gen.engine
def f():
(yield gen.Callback((1, 2)))((3, 4))
res = yield gen.Wait((1, 2))
self.assertEqual((3, 4), res)
self.stop()
self.run_gen(f)
def test_key_reuse(self):
@gen.engine
def f():
yield gen.Callback("k1")
yield gen.Callback("k1")
self.stop()
self.assertRaises(gen.KeyReuseError, self.run_gen, f)
def test_key_reuse_tuple(self):
@gen.engine
def f():
yield gen.Callback((1, 2))
yield gen.Callback((1, 2))
self.stop()
self.assertRaises(gen.KeyReuseError, self.run_gen, f)
def test_key_mismatch(self):
@gen.engine
def f():
yield gen.Callback("k1")
yield gen.Wait("k2")
self.stop()
self.assertRaises(gen.UnknownKeyError, self.run_gen, f)
def test_key_mismatch_tuple(self):
@gen.engine
def f():
yield gen.Callback((1, 2))
yield gen.Wait((2, 3))
self.stop()
self.assertRaises(gen.UnknownKeyError, self.run_gen, f)
def test_leaked_callback(self):
@gen.engine
def f():
yield gen.Callback("k1")
self.stop()
self.assertRaises(gen.LeakedCallbackError, self.run_gen, f)
def test_leaked_callback_tuple(self):
@gen.engine
def f():
yield gen.Callback((1, 2))
self.stop()
self.assertRaises(gen.LeakedCallbackError, self.run_gen, f)
def test_parallel_callback(self):
@gen.engine
def f():
for k in range(3):
self.io_loop.add_callback((yield gen.Callback(k)))
yield gen.Wait(1)
self.io_loop.add_callback((yield gen.Callback(3)))
yield gen.Wait(0)
yield gen.Wait(3)
yield gen.Wait(2)
self.stop()
self.run_gen(f)
def test_bogus_yield(self):
@gen.engine
def f():
yield 42
self.assertRaises(gen.BadYieldError, self.run_gen, f)
def test_bogus_yield_tuple(self):
@gen.engine
def f():
yield (1, 2)
self.assertRaises(gen.BadYieldError, self.run_gen, f)
def test_reuse(self):
@gen.engine
def f():
self.io_loop.add_callback((yield gen.Callback(0)))
yield gen.Wait(0)
self.stop()
self.run_gen(f)
self.run_gen(f)
def test_task(self):
@gen.engine
def f():
yield gen.Task(self.io_loop.add_callback)
self.stop()
self.run_gen(f)
def test_wait_all(self):
@gen.engine
def f():
(yield gen.Callback("k1"))("v1")
(yield gen.Callback("k2"))("v2")
results = yield gen.WaitAll(["k1", "k2"])
self.assertEqual(results, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_exception_in_yield(self):
@gen.engine
def f():
try:
yield gen.Wait("k1")
raise Exception("did not get expected exception")
except gen.UnknownKeyError:
pass
self.stop()
self.run_gen(f)
def test_resume_after_exception_in_yield(self):
@gen.engine
def f():
try:
yield gen.Wait("k1")
raise Exception("did not get expected exception")
except gen.UnknownKeyError:
pass
(yield gen.Callback("k2"))("v2")
self.assertEqual((yield gen.Wait("k2")), "v2")
self.stop()
self.run_gen(f)
def test_orphaned_callback(self):
@gen.engine
def f():
self.orphaned_callback = yield gen.Callback(1)
try:
self.run_gen(f)
raise Exception("did not get expected exception")
except gen.LeakedCallbackError:
pass
self.orphaned_callback()
def test_multi(self):
@gen.engine
def f():
(yield gen.Callback("k1"))("v1")
(yield gen.Callback("k2"))("v2")
results = yield [gen.Wait("k1"), gen.Wait("k2")]
self.assertEqual(results, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_multi_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield [
gen.Task(self.delay_callback, 3, arg="v1"),
gen.Task(self.delay_callback, 1, arg="v2"),
]
self.assertEqual(responses, ["v1", "v2"])
self.stop()
self.run_gen(f)
@skipOnTravis
@gen_test
def test_multi_performance(self):
# Yielding a list used to have quadratic performance; make
# sure a large list stays reasonable. On my laptop a list of
# 2000 used to take 1.8s; now it takes 0.12s.
start = time.time()
yield [gen.Task(self.io_loop.add_callback) for i in range(2000)]
end = time.time()
self.assertLess(end - start, 1.0)
@gen_test
def test_future(self):
result = yield self.async_future(1)
self.assertEqual(result, 1)
@gen_test
def test_multi_future(self):
results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2])
def test_arguments(self):
@gen.engine
def f():
(yield gen.Callback("noargs"))()
self.assertEqual((yield gen.Wait("noargs")), None)
(yield gen.Callback("1arg"))(42)
self.assertEqual((yield gen.Wait("1arg")), 42)
(yield gen.Callback("kwargs"))(value=42)
result = yield gen.Wait("kwargs")
self.assertTrue(isinstance(result, gen.Arguments))
self.assertEqual(((), dict(value=42)), result)
self.assertEqual(dict(value=42), result.kwargs)
(yield gen.Callback("2args"))(42, 43)
result = yield gen.Wait("2args")
self.assertTrue(isinstance(result, gen.Arguments))
self.assertEqual(((42, 43), {}), result)
self.assertEqual((42, 43), result.args)
def task_func(callback):
callback(None, error="foo")
result = yield gen.Task(task_func)
self.assertTrue(isinstance(result, gen.Arguments))
self.assertEqual(((None,), dict(error="foo")), result)
self.stop()
self.run_gen(f)
def test_stack_context_leak(self):
# regression test: repeated invocations of a gen-based
# function should not result in accumulated stack_contexts
def _stack_depth():
head = stack_context._state.contexts[1]
length = 0
while head is not None:
length += 1
head = head.old_contexts[1]
return length
@gen.engine
def inner(callback):
yield gen.Task(self.io_loop.add_callback)
callback()
@gen.engine
def outer():
for i in range(10):
yield gen.Task(inner)
stack_increase = _stack_depth() - initial_stack_depth
self.assertTrue(stack_increase <= 2)
self.stop()
initial_stack_depth = _stack_depth()
self.run_gen(outer)
def test_stack_context_leak_exception(self):
# same as previous, but with a function that exits with an exception
@gen.engine
def inner(callback):
yield gen.Task(self.io_loop.add_callback)
1 / 0
@gen.engine
def outer():
for i in range(10):
try:
yield gen.Task(inner)
except ZeroDivisionError:
pass
stack_increase = len(stack_context._state.contexts) - initial_stack_depth
self.assertTrue(stack_increase <= 2)
self.stop()
initial_stack_depth = len(stack_context._state.contexts)
self.run_gen(outer)
def function_with_stack_context(self, callback):
# Technically this function should stack_context.wrap its callback
# upon entry. However, it is very common for this step to be
# omitted.
def step2():
self.assertEqual(self.named_contexts, ['a'])
self.io_loop.add_callback(callback)
with stack_context.StackContext(self.named_context('a')):
self.io_loop.add_callback(step2)
@gen_test
def test_wait_transfer_stack_context(self):
# Wait should not pick up contexts from where callback was invoked,
# even if that function improperly fails to wrap its callback.
cb = yield gen.Callback('k1')
self.function_with_stack_context(cb)
self.assertEqual(self.named_contexts, [])
yield gen.Wait('k1')
self.assertEqual(self.named_contexts, [])
@gen_test
def test_task_transfer_stack_context(self):
yield gen.Task(self.function_with_stack_context)
self.assertEqual(self.named_contexts, [])
def test_raise_after_stop(self):
# This pattern will be used in the following tests so make sure
# the exception propagates as expected.
@gen.engine
def f():
self.stop()
1 / 0
with self.assertRaises(ZeroDivisionError):
self.run_gen(f)
def test_sync_raise_return(self):
# gen.Return is allowed in @gen.engine, but it may not be used
# to return a value.
@gen.engine
def f():
self.stop(42)
raise gen.Return()
result = self.run_gen(f)
self.assertEqual(result, 42)
def test_async_raise_return(self):
@gen.engine
def f():
yield gen.Task(self.io_loop.add_callback)
self.stop(42)
raise gen.Return()
result = self.run_gen(f)
self.assertEqual(result, 42)
def test_sync_raise_return_value(self):
@gen.engine
def f():
raise gen.Return(42)
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
def test_sync_raise_return_value_tuple(self):
@gen.engine
def f():
raise gen.Return((1, 2))
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
def test_async_raise_return_value(self):
@gen.engine
def f():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(42)
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
def test_async_raise_return_value_tuple(self):
@gen.engine
def f():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return((1, 2))
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
def test_return_value(self):
# It is an error to apply @gen.engine to a function that returns
# a value.
@gen.engine
def f():
return 42
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
def test_return_value_tuple(self):
# It is an error to apply @gen.engine to a function that returns
# a value.
@gen.engine
def f():
return (1, 2)
with self.assertRaises(gen.ReturnValueIgnoredError):
self.run_gen(f)
@skipNotCPython
def test_task_refcounting(self):
# On CPython, tasks and their arguments should be released immediately
# without waiting for garbage collection.
@gen.engine
def f():
class Foo(object):
pass
arg = Foo()
self.arg_ref = weakref.ref(arg)
task = gen.Task(self.io_loop.add_callback, arg=arg)
self.task_ref = weakref.ref(task)
yield task
self.stop()
self.run_gen(f)
self.assertIs(self.arg_ref(), None)
self.assertIs(self.task_ref(), None)
class GenCoroutineTest(AsyncTestCase):
def setUp(self):
# Stray StopIteration exceptions can lead to tests exiting prematurely,
# so we need explicit checks here to make sure the tests run all
# the way through.
self.finished = False
super(GenCoroutineTest, self).setUp()
def tearDown(self):
super(GenCoroutineTest, self).tearDown()
assert self.finished
@gen_test
def test_sync_gen_return(self):
@gen.coroutine
def f():
raise gen.Return(42)
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_async_gen_return(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(42)
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_sync_return(self):
@gen.coroutine
def f():
return 42
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@skipBefore33
@gen_test
def test_async_return(self):
# It is a compile-time error to return a value in a generator
# before Python 3.3, so we must test this with exec.
# Flatten the real global and local namespace into our fake globals:
# it's all global from the perspective of f().
global_namespace = dict(globals(), **locals())
local_namespace = {}
exec(textwrap.dedent("""
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
return 42
"""), global_namespace, local_namespace)
result = yield local_namespace['f']()
self.assertEqual(result, 42)
self.finished = True
@skipBefore33
@gen_test
def test_async_early_return(self):
# A yield statement exists but is not executed, which means
# this function "returns" via an exception. The exception is
# not raised until the coroutine machinery has set up its
# exception handling.
global_namespace = dict(globals(), **locals())
local_namespace = {}
exec(textwrap.dedent("""
@gen.coroutine
def f():
if True:
return 42
yield gen.Task(self.io_loop.add_callback)
"""), global_namespace, local_namespace)
result = yield local_namespace['f']()
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_sync_return_no_value(self):
@gen.coroutine
def f():
return
result = yield f()
self.assertEqual(result, None)
self.finished = True
@gen_test
def test_async_return_no_value(self):
# Without a return value we don't need Python 3.3.
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
return
result = yield f()
self.assertEqual(result, None)
self.finished = True
@gen_test
def test_sync_raise(self):
@gen.coroutine
def f():
1 / 0
# The exception is raised when the future is yielded
# (or equivalently when its result method is called),
# not when the function itself is called.
future = f()
with self.assertRaises(ZeroDivisionError):
yield future
self.finished = True
@gen_test
def test_async_raise(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
1 / 0
future = f()
with self.assertRaises(ZeroDivisionError):
yield future
self.finished = True
@gen_test
def test_pass_callback(self):
@gen.coroutine
def f():
raise gen.Return(42)
result = yield gen.Task(f)
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_replace_yieldpoint_exception(self):
# Test exception handling: a coroutine can catch one exception
# raised by a yield point and raise a different one.
@gen.coroutine
def f1():
1 / 0
@gen.coroutine
def f2():
try:
yield f1()
except ZeroDivisionError:
raise KeyError()
future = f2()
with self.assertRaises(KeyError):
yield future
self.finished = True
@gen_test
def test_swallow_yieldpoint_exception(self):
# Test exception handling: a coroutine can catch an exception
# raised by a yield point and not raise a different one.
@gen.coroutine
def f1():
1 / 0
@gen.coroutine
def f2():
try:
yield f1()
except ZeroDivisionError:
raise gen.Return(42)
result = yield f2()
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_replace_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and replaced.
@gen.coroutine
def f2():
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise KeyError()
future = f2()
with self.assertRaises(KeyError):
yield future
self.finished = True
@gen_test
def test_swallow_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and ignored.
@gen.coroutine
def f2():
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
self.io_loop.time() + 10)
except ZeroDivisionError:
raise gen.Return(42)
result = yield f2()
self.assertEqual(result, 42)
self.finished = True
class GenSequenceHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.finish("3")
class GenCoroutineSequenceHandler(RequestHandler):
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.finish("3")
class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
# just write, don't finish
self.write("3")
class GenTaskHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
io_loop = self.request.connection.stream.io_loop
client = AsyncHTTPClient(io_loop=io_loop)
response = yield gen.Task(client.fetch, self.get_argument('url'))
response.rethrow()
self.finish(b"got response: " + response.body)
class GenExceptionHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
# This test depends on the order of the two decorators.
io_loop = self.request.connection.stream.io_loop
yield gen.Task(io_loop.add_callback)
raise Exception("oops")
class GenCoroutineExceptionHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
# This test depends on the order of the two decorators.
io_loop = self.request.connection.stream.io_loop
yield gen.Task(io_loop.add_callback)
raise Exception("oops")
class GenYieldExceptionHandler(RequestHandler):
@asynchronous
@gen.engine
def get(self):
io_loop = self.request.connection.stream.io_loop
# Test the interaction of the two stack_contexts.
def fail_task(callback):
io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(fail_task)
raise Exception("did not get expected exception")
except ZeroDivisionError:
self.finish('ok')
class UndecoratedCoroutinesHandler(RequestHandler):
@gen.coroutine
def prepare(self):
self.chunks = []
yield gen.Task(IOLoop.current().add_callback)
self.chunks.append('1')
@gen.coroutine
def get(self):
self.chunks.append('2')
yield gen.Task(IOLoop.current().add_callback)
self.chunks.append('3')
yield gen.Task(IOLoop.current().add_callback)
self.write(''.join(self.chunks))
class AsyncPrepareErrorHandler(RequestHandler):
@gen.coroutine
def prepare(self):
yield gen.Task(IOLoop.current().add_callback)
raise HTTPError(403)
def get(self):
self.finish('ok')
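# As the two handlers above rely on, a @gen.coroutine applied to prepare()
# or to an HTTP verb method keeps the request open until the returned
# Future resolves, with no @asynchronous needed. A minimal illustrative
# handler (not registered in the test app below):
class _DelayedHandler(RequestHandler):
    @gen.coroutine
    def get(self):
        yield gen.Task(IOLoop.current().add_callback)
        self.write('done')  # the request finishes when the coroutine exits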
class GenWebTest(AsyncHTTPTestCase):
def get_app(self):
return Application([
('/sequence', GenSequenceHandler),
('/coroutine_sequence', GenCoroutineSequenceHandler),
('/coroutine_unfinished_sequence',
GenCoroutineUnfinishedSequenceHandler),
('/task', GenTaskHandler),
('/exception', GenExceptionHandler),
('/coroutine_exception', GenCoroutineExceptionHandler),
('/yield_exception', GenYieldExceptionHandler),
('/undecorated_coroutine', UndecoratedCoroutinesHandler),
('/async_prepare_error', AsyncPrepareErrorHandler),
])
def test_sequence_handler(self):
response = self.fetch('/sequence')
self.assertEqual(response.body, b"123")
def test_coroutine_sequence_handler(self):
response = self.fetch('/coroutine_sequence')
self.assertEqual(response.body, b"123")
def test_coroutine_unfinished_sequence_handler(self):
response = self.fetch('/coroutine_unfinished_sequence')
self.assertEqual(response.body, b"123")
def test_task_handler(self):
response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence')))
self.assertEqual(response.body, b"got response: 123")
def test_exception_handler(self):
# Make sure we get an error and not a timeout
with ExpectLog(app_log, "Uncaught exception GET /exception"):
response = self.fetch('/exception')
self.assertEqual(500, response.code)
def test_coroutine_exception_handler(self):
# Make sure we get an error and not a timeout
with ExpectLog(app_log, "Uncaught exception GET /coroutine_exception"):
response = self.fetch('/coroutine_exception')
self.assertEqual(500, response.code)
def test_yield_exception_handler(self):
response = self.fetch('/yield_exception')
self.assertEqual(response.body, b'ok')
def test_undecorated_coroutines(self):
response = self.fetch('/undecorated_coroutine')
self.assertEqual(response.body, b'123')
def test_async_prepare_error_handler(self):
response = self.fetch('/async_prepare_error')
self.assertEqual(response.code, 403)
if __name__ == '__main__':
unittest.main()


@@ -0,0 +1,22 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2012-06-14 01:10-0700\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"Language: \n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
#: extract_me.py:1
msgid "school"
msgstr "école"


@@ -0,0 +1,471 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import base64
import binascii
from contextlib import closing
import functools
import sys
import threading
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, url
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO
class HelloWorldHandler(RequestHandler):
def get(self):
name = self.get_argument("name", "world")
self.set_header("Content-Type", "text/plain")
self.finish("Hello %s!" % name)
class PostHandler(RequestHandler):
def post(self):
self.finish("Post arg1: %s, arg2: %s" % (
self.get_argument("arg1"), self.get_argument("arg2")))
class ChunkHandler(RequestHandler):
def get(self):
self.write("asdf")
self.flush()
self.write("qwer")
class AuthHandler(RequestHandler):
def get(self):
self.finish(self.request.headers["Authorization"])
class CountdownHandler(RequestHandler):
def get(self, count):
count = int(count)
if count > 0:
self.redirect(self.reverse_url("countdown", count - 1))
else:
self.write("Zero")
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
class UserAgentHandler(RequestHandler):
def get(self):
self.write(self.request.headers.get('User-Agent', 'User agent not set'))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
self.set_header('Content-Length', 42)
def _clear_headers_for_304(self):
# Tornado strips content-length from 304 responses, but here we
# want to simulate servers that include the headers anyway.
pass
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
def method(self):
self.write(self.request.method)
get = post = put = delete = options = patch = other = method
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/echopost", EchoPostHandler),
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
], gzip=True)
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b"Hello world!")
self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, b"Hello Ben!")
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = []
response = self.fetch("/hello",
streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
response = self.fetch("/post", method="POST",
body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
chunks = []
response = self.fetch("/chunk",
streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
def test_chunked_close(self):
# Test a case in which chunks spread read-callback processing
# over several IOLoop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
1
1
1
2
0
""".replace(b"\n", b"\r\n"), callback=stream.close)
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn, io_loop=self.io_loop)
stream.read_until(b"\r\n\r\n",
functools.partial(write_response, stream))
netutil.add_accept_handler(sock, accept_callback, self.io_loop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
resp = self.wait()
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
def test_streaming_stack_context(self):
chunks = []
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def streaming_cb(chunk):
chunks.append(chunk)
if chunk == b'qwer':
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', streaming_callback=streaming_cb)
self.assertEqual(chunks, [b'asdf', b'qwer'])
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_basic_auth(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_basic_auth_explicit_mode(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="basic").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
response = self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf")
response.rethrow()
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
self.assertEqual(302, response.code)
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
response = self.fetch("/countdown/2")
self.assertEqual(200, response.code)
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b"Zero", response.body)
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
response.body)
def test_body_encoding(self):
unicode_body = u("\xe9")
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
response = self.fetch("/echopost", method="POST", body=unicode_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
response = self.fetch("/echopost", method="POST",
body=byte_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
response = self.fetch("/echopost", method="POST", body=byte_body,
headers={"Content-Type": "application/blah"},
user_agent=u("foo"))
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
if header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
headers[k] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
self.assertEqual(len(first_line), 1)
self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def header_callback(header_line):
if header_line.startswith('Content-Type:'):
1 / 0
with ExceptionStackContext(error_handler):
self.fetch('/chunk', header_callback=header_callback)
self.assertEqual(len(exc_info), 1)
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_configure_defaults(self):
defaults = dict(user_agent='TestDefaultUserAgent')
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults)
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
self.assertEqual(response.body, b'TestDefaultUserAgent')
client.close()
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
response = self.fetch('/304_with_content_length')
self.assertEqual(response.code, 304)
self.assertEqual(response.headers['Content-Length'], '42')
def test_final_callback_stack_context(self):
# The final callback should be run outside of the httpclient's
# stack_context. We want to ensure that there is no stack_context
# between the user's callback and the IOLoop, so monkey-patch
# IOLoop.handle_callback_exception and disable the test harness's
# context with a NullContext.
# Note that this does not apply to secondary callbacks (header
# and streaming_callback), as errors there must be seen as errors
# by the http client so it can clean up the connection.
exc_info = []
def handle_callback_exception(callback):
exc_info.append(sys.exc_info())
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
self.http_client.fetch(self.get_url('/hello'),
lambda response: 1 / 0)
self.wait()
self.assertEqual(exc_info[0][0], ZeroDivisionError)
@gen_test
def test_future_interface(self):
response = yield self.http_client.fetch(self.get_url('/hello'))
self.assertEqual(response.body, b'Hello world!')
@gen_test
def test_future_http_error(self):
try:
yield self.http_client.fetch(self.get_url('/notfound'))
except HTTPError as e:
self.assertEqual(e.code, 404)
self.assertEqual(e.response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
url = self.get_url('/hello')
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
self.assertEqual(response2.body, b'Hello world!')
def test_all_methods(self):
for method in ['GET', 'DELETE', 'OPTIONS']:
response = self.fetch('/all_methods', method=method)
self.assertEqual(response.body, utf8(method))
for method in ['POST', 'PUT', 'PATCH']:
response = self.fetch('/all_methods', method=method, body=b'')
self.assertEqual(response.body, utf8(method))
response = self.fetch('/all_methods', method='HEAD')
self.assertEqual(response.body, b'')
response = self.fetch('/all_methods', method='OTHER',
allow_nonstandard_methods=True)
self.assertEqual(response.body, b'OTHER')
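# _RequestProxy (a private helper, exercised below) layers per-request
# attributes over client-wide defaults: lookup tries the HTTPRequest first,
# then the defaults dict, then raises AttributeError. A minimal usage
# sketch (the function name and values are ours):
def _request_proxy_sketch():
    proxy = _RequestProxy(HTTPRequest('http://example.com/'),
                          dict(user_agent='DefaultUA'))
    return proxy.user_agent  # 'DefaultUA', taken from the defaults dict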
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
user_agent='foo'),
dict())
self.assertEqual(proxy.user_agent, 'foo')
def test_default_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict(network_interface='foo'))
self.assertEqual(proxy.network_interface, 'foo')
def test_both_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
proxy_host='foo'),
dict(proxy_host='bar'))
self.assertEqual(proxy.proxy_host, 'foo')
def test_neither_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
response = HTTPResponse(HTTPRequest('http://example.com'),
200, headers={}, buffer=BytesIO())
s = str(response)
self.assertTrue(s.startswith('HTTPResponse('))
self.assertIn('code=200', s)
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
# TwistedIOLoop only supports the global reactor, so we can't have
# separate IOLoops for client and server threads.
raise unittest.SkipTest(
'Sync HTTPClient not compatible with TwistedIOLoop')
self.server_ioloop = IOLoop()
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
server = HTTPServer(app, io_loop=self.server_ioloop)
server.add_socket(sock)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
self.server_thread.start()
self.http_client = HTTPClient()
def tearDown(self):
self.server_ioloop.add_callback(self.server_ioloop.stop)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return 'http://localhost:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))
self.assertEqual(b'Hello world!', response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(assertion.exception.code, 404)
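# HTTPClient, as tested above, wraps an AsyncHTTPClient around a private
# IOLoop and blocks in run_sync, so it must not be used from within an
# already-running IOLoop. A minimal standalone sketch (function name is
# ours):
def _sync_fetch_sketch(url):
    client = HTTPClient()
    try:
        return client.fetch(url).body  # raises HTTPError on non-2xx codes
    finally:
        client.close()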


@@ -0,0 +1,661 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context, Resolver
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import unittest
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, asynchronous
from contextlib import closing
import datetime
import os
import shutil
import socket
import ssl
import sys
import tempfile
class HandlerBaseTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', self.__class__.Handler)])
def fetch_json(self, *args, **kwargs):
response = self.fetch(*args, **kwargs)
response.rethrow()
return json_decode(response.body)
class HelloWorldRequestHandler(RequestHandler):
def initialize(self, protocol="http"):
self.expected_protocol = protocol
def get(self):
if self.request.protocol != self.expected_protocol:
raise Exception("unexpected protocol")
self.finish("Hello world")
def post(self):
self.finish("Got %d bytes in POST" % len(self.request.body))
# In pre-1.0 versions of OpenSSL, SSLv23 clients always send SSLv2
# ClientHello messages, which are rejected by SSLv3 and TLSv1
# servers. Note that while OPENSSL_VERSION_INFO was formally
# introduced in Python 3.2, it was present but undocumented in
# Python 2.7.
skipIfOldSSL = unittest.skipIf(
getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0),
"old version of ssl module and/or openssl")
class BaseSSLTest(AsyncHTTPSTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler,
dict(protocol="https"))])
class SSLTestMixin(object):
def get_ssl_options(self):
return dict(ssl_version=self.get_ssl_version(),
**AsyncHTTPSTestCase.get_ssl_options())
def get_ssl_version(self):
raise NotImplementedError()
def test_ssl(self):
response = self.fetch('/')
self.assertEqual(response.body, b"Hello world")
def test_large_post(self):
response = self.fetch('/',
method='POST',
body='A' * 5000)
self.assertEqual(response.body, b"Got 5000 bytes in POST")
def test_non_ssl_request(self):
# Make sure the server closes the connection when it gets a non-ssl
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
self.assertEqual(response.code, 599)
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
# of SSLv23 allows it.
class SSLv23Test(BaseSSLTest, SSLTestMixin):
def get_ssl_version(self):
return ssl.PROTOCOL_SSLv23
@skipIfOldSSL
class SSLv3Test(BaseSSLTest, SSLTestMixin):
def get_ssl_version(self):
return ssl.PROTOCOL_SSLv3
@skipIfOldSSL
class TLSv1Test(BaseSSLTest, SSLTestMixin):
def get_ssl_version(self):
return ssl.PROTOCOL_TLSv1
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class SSLContextTest(BaseSSLTest, SSLTestMixin):
def get_ssl_options(self):
context = ssl_options_to_context(
AsyncHTTPSTestCase.get_ssl_options(self))
assert isinstance(context, ssl.SSLContext)
return context
class BadSSLOptionsTest(unittest.TestCase):
def test_missing_arguments(self):
application = Application()
self.assertRaises(KeyError, HTTPServer, application, ssl_options={
"keyfile": "/__missing__.crt",
})
def test_missing_key(self):
"""A missing SSL key should cause an immediate exception."""
application = Application()
module_dir = os.path.dirname(__file__)
existing_certificate = os.path.join(module_dir, 'test.crt')
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
"certfile": "/__mising__.crt",
})
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
"certfile": existing_certificate,
"keyfile": "/__missing__.key"
})
# This actually works because both files exist
HTTPServer(application, ssl_options={
"certfile": existing_certificate,
"keyfile": existing_certificate
})
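# ssl_options, as validated above, is either a dict of keyword arguments
# for ssl.wrap_socket (certfile is required and the named files must
# exist) or, where ssl.SSLContext is available, a ready-made SSLContext.
# A minimal sketch with hypothetical paths:
def _ssl_server_sketch(app):
    return HTTPServer(app, ssl_options={
        "certfile": "/etc/ssl/server.crt",  # hypothetical; must exist
        "keyfile": "/etc/ssl/server.key",
    })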
class MultipartTestHandler(RequestHandler):
def post(self):
self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
"argument": self.get_argument("argument"),
"filename": self.request.files["files"][0].filename,
"filebody": _unicode(self.request.files["files"][0]["body"]),
})
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
def set_request(self, request):
self.__next_request = request
def _on_connect(self):
self.stream.write(self.__next_request)
self.__next_request = None
self.stream.read_until(b"\r\n\r\n", self._on_headers)
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
return [("/multipart", MultipartTestHandler),
("/hello", HelloWorldRequestHandler)]
def get_app(self):
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
with closing(Resolver(io_loop=self.io_loop)) as resolver:
with closing(SimpleAsyncHTTPClient(self.io_loop,
resolver=resolver)) as client:
conn = RawRequestHTTPConnection(
self.io_loop, client,
httpclient._RequestProxy(
httpclient.HTTPRequest(self.get_url("/")),
dict(httpclient.HTTPRequest._DEFAULTS)),
None, self.stop,
1024 * 1024, resolver)
conn.set_request(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
response = self.wait()
response.rethrow()
return response
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
# anything (we use utf8 by default).
response = self.raw_fetch([
b"POST /multipart HTTP/1.0",
b"Content-Type: multipart/form-data; boundary=1234567890",
b"X-Header-encoding-test: \xe9",
],
b"\r\n".join([
b"Content-Disposition: form-data; name=argument",
b"",
u("\u00e1").encode("utf-8"),
b"--1234567890",
u('Content-Disposition: form-data; name="files"; filename="\u00f3"').encode("utf8"),
b"",
u("\u00fa").encode("utf-8"),
b"--1234567890--",
b"",
]))
data = json_decode(response.body)
self.assertEqual(u("\u00e9"), data["header"])
self.assertEqual(u("\u00e1"), data["argument"])
self.assertEqual(u("\u00f3"), data["filename"])
self.assertEqual(u("\u00fa"), data["filebody"])
def test_100_continue(self):
# Run through a 100-continue interaction by hand:
# When given Expect: 100-continue, we get a 100 response after the
# headers, and then the real response after the body.
stream = IOStream(socket.socket(), io_loop=self.io_loop)
stream.connect(("localhost", self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
b"Expect: 100-continue",
b"Connection: close",
b"\r\n"]), callback=self.stop)
self.wait()
stream.read_until(b"\r\n\r\n", self.stop)
data = self.wait()
self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
stream.write(b"a" * 1024)
stream.read_until(b"\r\n", self.stop)
first_line = self.wait()
self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
stream.read_until(b"\r\n\r\n", self.stop)
header_data = self.wait()
headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
stream.read_bytes(int(headers["Content-Length"]), self.stop)
body = self.wait()
self.assertEqual(body, b"Got 1024 bytes in POST")
stream.close()
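# The exchange above follows RFC 2616 section 8.2.3: the client sends its
# headers with "Expect: 100-continue", waits for an interim "HTTP/1.1 100"
# response, and only then transmits the body and reads the final response.
# A helper sketch for the raw request (the function name is ours):
def _expect_continue_request(body_len):
    return b"\r\n".join([b"POST /hello HTTP/1.1",
                         utf8("Content-Length: %d" % body_len),
                         b"Expect: 100-continue",
                         b"Connection: close",
                         b"\r\n"])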
class EchoHandler(RequestHandler):
def get(self):
self.write(recursive_unicode(self.request.arguments))
def post(self):
self.write(recursive_unicode(self.request.arguments))
class TypeCheckHandler(RequestHandler):
def prepare(self):
self.errors = {}
fields = [
('method', str),
('uri', str),
('version', str),
('remote_ip', str),
('protocol', str),
('host', str),
('path', str),
('query', str),
]
for field, expected_type in fields:
self.check_type(field, getattr(self.request, field), expected_type)
self.check_type('header_key', list(self.request.headers.keys())[0], str)
self.check_type('header_value', list(self.request.headers.values())[0], str)
self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)
# secure cookies
self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type)
def post(self):
self.check_type('body', self.request.body, bytes_type)
self.write(self.errors)
def get(self):
self.write(self.errors)
def check_type(self, name, obj, expected_type):
actual_type = type(obj)
if expected_type != actual_type:
self.errors[name] = "expected %s, got %s" % (expected_type,
actual_type)
class HTTPServerTest(AsyncHTTPTestCase):
def get_app(self):
return Application([("/echo", EchoHandler),
("/typecheck", TypeCheckHandler),
("//doubleslash", EchoHandler),
])
def test_query_string_encoding(self):
response = self.fetch("/echo?foo=%C3%A9")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u("\u00e9")]})
def test_empty_query_string(self):
response = self.fetch("/echo?foo=&foo=")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u(""), u("")]})
def test_empty_post_parameters(self):
response = self.fetch("/echo", method="POST", body="foo=&bar=")
data = json_decode(response.body)
self.assertEqual(data, {u("foo"): [u("")], u("bar"): [u("")]})
def test_types(self):
headers = {"Cookie": "foo=bar"}
response = self.fetch("/typecheck?foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
def test_double_slash(self):
# urlparse.urlsplit (which tornado.httpserver used to use
# incorrectly) would parse paths beginning with "//" as
# protocol-relative urls.
response = self.fetch("//doubleslash")
self.assertEqual(200, response.code)
self.assertEqual(json_decode(response.body), {})
class HTTPServerRawTest(AsyncHTTPTestCase):
def get_app(self):
return Application([
('/echo', EchoHandler),
])
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.wait()
def tearDown(self):
self.stream.close()
super(HTTPServerRawTest, self).tearDown()
def test_empty_request(self):
self.stream.close()
self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
self.wait()
def test_malformed_first_line(self):
with ExpectLog(gen_log, '.*Malformed HTTP request line'):
self.stream.write(b'asdf\r\n\r\n')
# TODO: need an async version of ExpectLog so we don't need
# hard-coded timeouts here.
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
self.stop)
self.wait()
def test_malformed_headers(self):
with ExpectLog(gen_log, '.*Malformed HTTP headers'):
self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
self.stop)
self.wait()
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
self.write(dict(remote_ip=self.request.remote_ip,
remote_protocol=self.request.protocol))
def get_httpserver_options(self):
return dict(xheaders=True)
def test_ip_headers(self):
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
"4.4.4.4")
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"],
"4.4.4.4")
valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
"2620:0:1cfe:face:b00c::3")
valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"],
"2620:0:1cfe:face:b00c::3")
invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
self.assertEqual(
self.fetch_json("/", headers=invalid_chars)["remote_ip"],
"127.0.0.1")
invalid_chars_list = {"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"}
self.assertEqual(
self.fetch_json("/", headers=invalid_chars_list)["remote_ip"],
"127.0.0.1")
invalid_host = {"X-Real-IP": "www.google.com"}
self.assertEqual(
self.fetch_json("/", headers=invalid_host)["remote_ip"],
"127.0.0.1")
def test_scheme_headers(self):
self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")
https_scheme = {"X-Scheme": "https"}
self.assertEqual(
self.fetch_json("/", headers=https_scheme)["remote_protocol"],
"https")
https_forwarded = {"X-Forwarded-Proto": "https"}
self.assertEqual(
self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
"https")
bad_forwarded = {"X-Forwarded-Proto": "unknown"}
self.assertEqual(
self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
"http")
class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
def get_app(self):
return Application([('/', XHeaderTest.Handler)])
def get_httpserver_options(self):
output = super(SSLXHeaderTest, self).get_httpserver_options()
output['xheaders'] = True
return output
def test_request_without_xprotocol(self):
self.assertEqual(self.fetch_json("/")["remote_protocol"], "https")
http_scheme = {"X-Scheme": "http"}
self.assertEqual(
self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http")
bad_scheme = {"X-Scheme": "unknown"}
self.assertEqual(
self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https")
class ManualProtocolTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
self.write(dict(protocol=self.request.protocol))
def get_httpserver_options(self):
return dict(protocol='https')
def test_manual_protocol(self):
self.assertEqual(self.fetch_json('/')['protocol'], 'https')
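# With xheaders=True (as above), the server trusts headers set by a
# front-end proxy: X-Real-IP/X-Forwarded-For replace remote_ip when they
# hold a valid address, and X-Scheme/X-Forwarded-Proto replace protocol
# when they are "http" or "https"; the protocol= option supplies the value
# used when no such header is present. A minimal server sketch:
def _xheaders_server_sketch(app, io_loop):
    return HTTPServer(app, io_loop=io_loop, xheaders=True)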
@unittest.skipIf(not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin',
"unix sockets not supported on this platform")
class UnixSocketTest(AsyncTestCase):
"""HTTPServers can listen on Unix sockets too.
Why would you want to do this? Nginx can proxy to backends listening
on unix sockets, for one thing (and managing a namespace for unix
sockets can be easier than managing a bunch of TCP port numbers).
Unfortunately, there's no way to specify a unix socket in a url for
an HTTP client, so we have to test this by hand.
"""
def setUp(self):
super(UnixSocketTest, self).setUp()
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdir)
super(UnixSocketTest, self).tearDown()
def test_unix_socket(self):
sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
server = HTTPServer(app, io_loop=self.io_loop)
server.add_socket(sock)
stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
stream.connect(sockfile, self.stop)
self.wait()
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
stream.read_until(b"\r\n", self.stop)
response = self.wait()
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
stream.read_bytes(int(headers["Content-Length"]), self.stop)
body = self.wait()
self.assertEqual(body, b"Hello world")
stream.close()
server.stop()
class KeepAliveTest(AsyncHTTPTestCase):
"""Tests various scenarios for HTTP 1.1 keep-alive support.
These tests don't use AsyncHTTPClient because we want to control
connection reuse and closing.
"""
def get_app(self):
class HelloHandler(RequestHandler):
def get(self):
self.finish('Hello world')
class LargeHandler(RequestHandler):
def get(self):
# 512KB should be bigger than the socket buffers so it will
# be written out in chunks.
self.write(''.join(chr(i % 256) * 1024 for i in range(512)))
class FinishOnCloseHandler(RequestHandler):
@asynchronous
def get(self):
self.flush()
def on_connection_close(self):
# This is not very realistic, but finishing the request
# from the close callback has the right timing to mimic
# some errors seen in the wild.
self.finish('closed')
return Application([('/', HelloHandler),
('/large', LargeHandler),
('/finish_on_close', FinishOnCloseHandler)])
def setUp(self):
super(KeepAliveTest, self).setUp()
self.http_version = b'HTTP/1.1'
def tearDown(self):
# We just closed the client side of the socket; let the IOLoop run
# once to make sure the server side got the message.
self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
self.wait()
if hasattr(self, 'stream'):
self.stream.close()
super(KeepAliveTest, self).tearDown()
# The next few methods are a crude manual http client
def connect(self):
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.wait()
def read_headers(self):
self.stream.read_until(b'\r\n', self.stop)
first_line = self.wait()
self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
self.stream.read_until(b'\r\n\r\n', self.stop)
header_bytes = self.wait()
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
return headers
def read_response(self):
headers = self.read_headers()
self.stream.read_bytes(int(headers['Content-Length']), self.stop)
body = self.wait()
self.assertEqual(b'Hello world', body)
def close(self):
self.stream.close()
del self.stream
def test_two_requests(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
self.read_response()
self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
self.read_response()
self.close()
def test_request_close(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
self.read_response()
self.stream.read_until_close(callback=self.stop)
data = self.wait()
self.assertTrue(not data)
self.close()
# keep-alive is supported for HTTP/1.0 too, but it's opt-in
def test_http10(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
self.read_response()
self.stream.read_until_close(callback=self.stop)
data = self.wait()
self.assertTrue(not data)
self.close()
def test_http10_keepalive(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.close()
def test_pipelined_requests(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
self.read_response()
self.read_response()
self.close()
def test_pipelined_cancel(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
# only read once
self.read_response()
self.close()
def test_cancel_during_download(self):
self.connect()
self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
self.read_headers()
self.stream.read_bytes(1024, self.stop)
self.wait()
self.close()
def test_finish_while_closed(self):
self.connect()
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
self.read_headers()
self.close()
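# As test_http10 and test_http10_keepalive above demonstrate, HTTP/1.0
# connections close after one response unless the client opts in. The raw
# request that asks to keep the connection open:
def _http10_keepalive_request():
    return b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n'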


@@ -0,0 +1,255 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp
from tornado.escape import utf8
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
import datetime
import logging
import time
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
url = url_concat(
"https://localhost/path",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
url = url_concat(
"https://localhost/path",
[('y', '/y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
url = url_concat(
"https://localhost/path?",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
url = url_concat(
"https://localhost/path?x",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat(
"https://localhost/path?x&",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat(
"https://localhost/path?a=1&b=2",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
[],
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
class MultipartFormDataTest(unittest.TestCase):
def test_file_upload(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_unquoted_names(self):
# quotes are optional unless special characters are present
data = b"""\
--1234
Content-Disposition: form-data; name=files; filename=ab.txt
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_special_filenames(self):
filenames = ['a;b.txt',
'a"b.txt',
'a";b.txt',
'a;"b.txt',
'a";";.txt',
'a\\"b.txt',
'a\\b.txt',
]
for filename in filenames:
logging.debug("trying filename %r", filename)
data = """\
--1234
Content-Disposition: form-data; name="files"; filename="%s"
Foo
--1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
data = utf8(data.replace("\n", "\r\n"))
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], filename)
self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b'"1234"', data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_missing_headers(self):
data = b'''\
--1234
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data missing headers"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_invalid_content_disposition(self):
data = b'''\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"
Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_line_does_not_end_with_correct_line_break(self):
data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_content_disposition_header_without_name_parameter(self):
data = b"""\
--1234
Content-Disposition: form-data; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data value missing name"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_data_after_final_boundary(self):
# The spec requires that data after the final boundary be ignored.
# http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
# In practice, some libraries include an extra CRLF after the boundary.
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--
""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
class HTTPHeadersTest(unittest.TestCase):
def test_multi_line(self):
# Lines beginning with whitespace are appended to the previous line
# with any leading whitespace replaced by a single space.
# Note that while multi-line headers are a part of the HTTP spec,
# their use is strongly discouraged.
data = """\
Foo: bar
baz
Asdf: qwer
\tzxcv
Foo: even
more
lines
""".replace("\n", "\r\n")
headers = HTTPHeaders.parse(data)
self.assertEqual(headers["asdf"], "qwer zxcv")
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
self.assertEqual(headers["Foo"], "bar baz,even more lines")
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
self.assertEqual(sorted(list(headers.get_all())),
[("Asdf", "qwer zxcv"),
("Foo", "bar baz"),
("Foo", "even more lines")])
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
TIMESTAMP = 1359312200.503611
EXPECTED = 'Sun, 27 Jan 2013 18:43:20 GMT'
def check(self, value):
self.assertEqual(format_timestamp(value), self.EXPECTED)
def test_unix_time_float(self):
self.check(self.TIMESTAMP)
def test_unix_time_int(self):
self.check(int(self.TIMESTAMP))
def test_struct_time(self):
self.check(time.gmtime(self.TIMESTAMP))
def test_time_tuple(self):
tup = tuple(time.gmtime(self.TIMESTAMP))
self.assertEqual(9, len(tup))
self.check(tup)
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
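# format_timestamp, as checked above, accepts a unix time (int or float),
# a struct_time or 9-tuple, or a naive UTC datetime, and always emits the
# HTTP date format:
def _format_timestamp_sketch():
    return format_timestamp(1359312200)  # 'Sun, 27 Jan 2013 18:43:20 GMT'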


@@ -0,0 +1,45 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.util import unittest
class ImportTest(unittest.TestCase):
def test_import_everything(self):
# Some of our modules are not otherwise tested. Import them
# all (unless they have external dependencies) here to at
# least ensure that there are no syntax errors.
import tornado.auth
import tornado.autoreload
import tornado.concurrent
# import tornado.curl_httpclient # depends on pycurl
import tornado.escape
import tornado.gen
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.tcpserver
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
# for modules with dependencies, if those dependencies can be loaded,
# load them too.
def test_import_pycurl(self):
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient


@@ -0,0 +1,333 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import datetime
import functools
import socket
import sys
import threading
import time
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
try:
from concurrent import futures
except ImportError:
futures = None
class TestIOLoop(AsyncTestCase):
@skipOnTravis
def test_add_callback_wakeup(self):
# Make sure that add_callback from inside a running IOLoop
# wakes up the IOLoop immediately instead of waiting for a timeout.
def callback():
self.called = True
self.stop()
def schedule_callback():
self.called = False
self.io_loop.add_callback(callback)
# Store away the time so we can check if we woke up immediately
self.start_time = time.time()
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
self.wait()
self.assertAlmostEqual(time.time(), self.start_time, places=2)
self.assertTrue(self.called)
@skipOnTravis
def test_add_callback_wakeup_other_thread(self):
def target():
# sleep a bit to let the ioloop go into its poll loop
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
self.assertAlmostEqual(time.time(), self.stop_time, places=2)
thread.join()
def test_add_timeout_timedelta(self):
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
self.io_loop.add_handler(sock.fileno(), lambda fd, events: None,
IOLoop.READ)
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
self.assertRaises(Exception, self.io_loop.add_handler,
sock.fileno(), lambda fd, events: None,
IOLoop.READ)
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
def test_remove_without_add(self):
# remove_handler should not throw an exception if called on an fd
# that was never added.
sock, port = bind_unused_port()
try:
self.io_loop.remove_handler(sock.fileno())
finally:
sock.close()
def test_add_callback_from_signal(self):
# cheat a little bit and just run this normally, since we can't
# easily simulate the races that happen with real signal handlers
self.io_loop.add_callback_from_signal(self.stop)
self.wait()
def test_add_callback_from_signal_other_thread(self):
# Very crude test, just to make sure that we cover this case.
# This also happens to be the first test where we run an IOLoop in
# a non-main thread.
other_ioloop = IOLoop()
thread = threading.Thread(target=other_ioloop.start)
thread.start()
other_ioloop.add_callback_from_signal(other_ioloop.stop)
thread.join()
other_ioloop.close()
def test_add_callback_while_closing(self):
# Issue #635: add_callback() should raise a clean exception
# if called while another thread is closing the IOLoop.
closing = threading.Event()
def target():
other_ioloop.add_callback(other_ioloop.stop)
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
closing.wait()
for i in range(1000):
try:
other_ioloop.add_callback(lambda: None)
except RuntimeError as e:
self.assertEqual("IOLoop is closing", str(e))
break
def test_handle_callback_exception(self):
# IOLoop.handle_callback_exception can be overridden to catch
# exceptions in callbacks.
def handle_callback_exception(callback):
self.assertIs(sys.exc_info()[0], ZeroDivisionError)
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
# remove the test StackContext that would see this uncaught
# exception as a test failure.
self.io_loop.add_callback(lambda: 1 / 0)
self.wait()
@skipIfNonUnix # just because socketpair is so convenient
def test_read_while_writeable(self):
# Ensure that write events don't come in while we're waiting for
# a read and haven't asked for writeability. (The reverse is
# difficult to test for.)
client, server = socket.socketpair()
try:
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
self.io_loop.add_timeout(self.io_loop.time() + 0.01,
functools.partial(server.send, b'asdf'))
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
client.close()
server.close()
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
def test_remove_timeout_cleanup(self):
# Add and remove enough callbacks to trigger cleanup.
# Not a very thorough test, but it ensures that the cleanup code
# gets executed and doesn't blow up. This test is only really useful
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600,
lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_current(self):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
class TestIOLoopAddCallback(AsyncTestCase):
def setUp(self):
super(TestIOLoopAddCallback, self).setUp()
self.active_contexts = []
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback(callback, *args, **kwargs)
@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
def test_pre_wrap(self):
# A pre-wrapped callback is run in the context in which it was
# wrapped, not when it was added to the IOLoop.
def f1():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)
self.wait()
def test_pre_wrap_with_args(self):
# Same as test_pre_wrap, but the function takes arguments.
# Implementation note: The function must not be wrapped in a
# functools.partial until after it has been passed through
# stack_context.wrap
def f1(foo, bar):
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop((foo, bar))
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)
result = self.wait()
self.assertEqual(result, (1, 2))
class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
# Repeat the add_callback tests using add_callback_from_signal
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback_from_signal(callback, *args, **kwargs)
@unittest.skipIf(futures is None, "futures module not present")
class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
self.io_loop.add_future(pool.submit(lambda: None),
lambda future: self.stop(future))
future = self.wait()
self.assertTrue(future.done())
self.assertTrue(future.result() is None)
def test_add_future_stack_context(self):
ready = threading.Event()
def task():
# we must wait for the ioloop callback to be scheduled before
# the task completes to ensure that add_future adds the callback
# asynchronously (which is the scenario in which capturing
# the stack_context matters)
ready.wait(1)
            assert ready.is_set(), "timed out"
raise Exception("worker")
def callback(future):
self.future = future
raise Exception("callback")
def handle_exception(typ, value, traceback):
self.exception = value
self.stop()
return True
# stack_context propagates to the ioloop callback, but the worker
# task just has its exceptions caught and saved in the Future.
with futures.ThreadPoolExecutor(1) as pool:
with ExceptionStackContext(handle_exception):
self.io_loop.add_future(pool.submit(task), callback)
ready.set()
self.wait()
self.assertEqual(self.exception.args[0], "callback")
self.assertEqual(self.future.exception().args[0], "worker")
class TestIOLoopRunSync(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_sync_result(self):
self.assertEqual(self.io_loop.run_sync(lambda: 42), 42)
def test_sync_exception(self):
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(lambda: 1 / 0)
def test_async_result(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
raise gen.Return(42)
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_callback)
1 / 0
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
def test_current(self):
def f():
self.assertIs(IOLoop.current(), self.io_loop)
self.io_loop.run_sync(f)
def test_timeout(self):
@gen.coroutine
def f():
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,545 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import netutil
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import errno
import logging
import os
import platform
import socket
import ssl
import sys
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello")
class TestIOStreamWebMixin(object):
def _make_client_iostream(self):
raise NotImplementedError()
def get_app(self):
return Application([('/', HelloHandler)])
def test_connection_closed(self):
# When a server sends a response and then closes the connection,
# the client must be allowed to read the data before the IOStream
# closes itself. Epoll reports closed connections with a separate
# EPOLLRDHUP event delivered at the same time as the read event,
# while kqueue reports them as a second read/write event with an EOF
# flag.
response = self.fetch("/", headers={"Connection": "close"})
response.rethrow()
def test_read_until_close(self):
stream = self._make_client_iostream()
stream.connect(('localhost', self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"GET / HTTP/1.0\r\n\r\n")
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.startswith(b"HTTP/1.0 200"))
self.assertTrue(data.endswith(b"Hello"))
def test_read_zero_bytes(self):
self.stream = self._make_client_iostream()
self.stream.connect(("localhost", self.get_http_port()),
callback=self.stop)
self.wait()
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
self.assertEqual(data, b"HTTP/1.0 ")
# zero bytes
self.stream.read_bytes(0, self.stop)
data = self.wait()
self.assertEqual(data, b"")
# another normal read
self.stream.read_bytes(3, self.stop)
data = self.wait()
self.assertEqual(data, b"200")
self.stream.close()
def test_write_while_connecting(self):
stream = self._make_client_iostream()
connected = [False]
def connected_callback():
connected[0] = True
self.stop()
stream.connect(("localhost", self.get_http_port()),
callback=connected_callback)
# unlike the previous tests, try to write before the connection
# is complete.
written = [False]
def write_callback():
written[0] = True
self.stop()
stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n",
callback=write_callback)
self.assertTrue(not connected[0])
# by the time the write has flushed, the connection callback has
# also run
try:
self.wait(lambda: connected[0] and written[0])
finally:
logging.debug((connected, written))
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.endswith(b"Hello"))
stream.close()
class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
raise NotImplementedError()
def _make_client_iostream(self, connection, **kwargs):
raise NotImplementedError()
def make_iostream_pair(self, **kwargs):
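        # Returns a connected [server, client] pair of IOStreams, built by
        # accepting a single connection on an unused loopback port.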
listener, port = bind_unused_port()
streams = [None, None]
def accept_callback(connection, address):
streams[0] = self._make_server_iostream(connection, **kwargs)
if isinstance(streams[0], SSLIOStream):
# HACK: The SSL handshake won't complete (and
# therefore the client connect callback won't be
            # run) until the server side has tried to do something
# with the connection. For these tests we want both
# sides to connect before we do anything else with the
# connection, so we must cause some dummy activity on the
# server. If this turns out to be useful for real apps
# it should have a cleaner interface.
streams[0]._add_io_state(IOLoop.READ)
self.stop()
def connect_callback():
streams[1] = client_stream
self.stop()
netutil.add_accept_handler(listener, accept_callback,
io_loop=self.io_loop)
client_stream = self._make_client_iostream(socket.socket(), **kwargs)
client_stream.connect(('127.0.0.1', port),
callback=connect_callback)
self.wait(condition=lambda: all(streams))
self.io_loop.remove_handler(listener.fileno())
listener.close()
return streams
def test_streaming_callback_with_data_in_buffer(self):
server, client = self.make_iostream_pair()
client.write(b"abcd\r\nefgh")
server.read_until(b"\r\n", self.stop)
data = self.wait()
self.assertEqual(data, b"abcd\r\n")
def closed_callback(chunk):
self.fail()
server.read_until_close(callback=closed_callback,
streaming_callback=self.stop)
# self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
data = self.wait()
self.assertEqual(data, b"efgh")
server.close()
client.close()
def test_write_zero_bytes(self):
# Attempting to write zero bytes should run the callback without
# going into an infinite loop.
server, client = self.make_iostream_pair()
server.write(b'', callback=self.stop)
self.wait()
# As a side effect, the stream is now listening for connection
# close (if it wasn't already), but is not listening for writes
self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
server.close()
client.close()
def test_connection_refused(self):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
# epoll IOLoop in this respect)
server_socket, port = bind_unused_port()
server_socket.close()
stream = IOStream(socket.socket(), self.io_loop)
self.connect_called = False
def connect_callback():
self.connect_called = True
stream.set_close_callback(self.stop)
# log messages vary by platform and ioloop implementation
with ExpectLog(gen_log, ".*", required=False):
stream.connect(("localhost", port), connect_callback)
self.wait()
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
if sys.platform != 'cygwin':
# cygwin's errnos don't match those used on native windows python
self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)
def test_gaierror(self):
# Test that IOStream sets its exc_info on getaddrinfo error
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = IOStream(s, io_loop=self.io_loop)
stream.set_close_callback(self.stop)
# To reliably generate a gaierror we use a malformed domain name
# instead of a name that's simply unlikely to exist (since
# opendns and some ISPs return bogus addresses for nonexistent
# domains instead of the proper error codes).
with ExpectLog(gen_log, "Connect error"):
stream.connect(('an invalid domain', 54321))
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
def test_read_callback_error(self):
# Test that IOStream sets its exc_info when a read callback throws
server, client = self.make_iostream_pair()
try:
server.set_close_callback(self.stop)
with ExpectLog(
app_log, "(Uncaught exception|Exception in callback)"
):
# Clear ExceptionStackContext so IOStream catches error
with NullContext():
server.read_bytes(1, callback=lambda data: 1 / 0)
client.write(b"1")
self.wait()
self.assertTrue(isinstance(server.error, ZeroDivisionError))
finally:
server.close()
client.close()
def test_streaming_callback(self):
server, client = self.make_iostream_pair()
try:
chunks = []
final_called = []
def streaming_callback(data):
chunks.append(data)
self.stop()
def final_callback(data):
self.assertFalse(data)
final_called.append(True)
self.stop()
server.read_bytes(6, callback=final_callback,
streaming_callback=streaming_callback)
client.write(b"1234")
self.wait(condition=lambda: chunks)
client.write(b"5678")
self.wait(condition=lambda: final_called)
self.assertEqual(chunks, [b"1234", b"56"])
# the rest of the last chunk is still in the buffer
server.read_bytes(2, callback=self.stop)
data = self.wait()
self.assertEqual(data, b"78")
finally:
server.close()
client.close()
def test_streaming_until_close(self):
server, client = self.make_iostream_pair()
try:
chunks = []
closed = [False]
def streaming_callback(data):
chunks.append(data)
self.stop()
def close_callback(data):
assert not data, data
closed[0] = True
self.stop()
client.read_until_close(callback=close_callback,
streaming_callback=streaming_callback)
server.write(b"1234")
self.wait(condition=lambda: len(chunks) == 1)
server.write(b"5678", self.stop)
self.wait()
server.close()
self.wait(condition=lambda: closed[0])
self.assertEqual(chunks, [b"1234", b"5678"])
finally:
server.close()
client.close()
def test_delayed_close_callback(self):
# The scenario: Server closes the connection while there is a pending
# read that can be served out of buffered data. The client does not
# run the close_callback as soon as it detects the close, but rather
# defers it until after the buffered read has finished.
server, client = self.make_iostream_pair()
try:
client.set_close_callback(self.stop)
server.write(b"12")
chunks = []
def callback1(data):
chunks.append(data)
client.read_bytes(1, callback2)
server.close()
def callback2(data):
chunks.append(data)
client.read_bytes(1, callback1)
self.wait() # stopped by close_callback
self.assertEqual(chunks, [b"1", b"2"])
finally:
server.close()
client.close()
def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
# close notifications must be delayed until all data has been
# drained into the IOStream buffer. (epoll used to use out-of-band
# close events with EPOLLRDHUP, but no longer)
#
# This depends on the read_chunk_size being smaller than the
# OS socket buffer, so make it small.
server, client = self.make_iostream_pair(read_chunk_size=256)
try:
server.write(b"A" * 512)
client.read_bytes(256, self.stop)
data = self.wait()
self.assertEqual(b"A" * 256, data)
server.close()
# Allow the close to propagate to the client side of the
# connection. Using add_callback instead of add_timeout
# doesn't seem to work, even with multiple iterations
self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
self.wait()
client.read_bytes(256, self.stop)
data = self.wait()
self.assertEqual(b"A" * 256, data)
finally:
server.close()
client.close()
def test_read_until_close_after_close(self):
# Similar to test_delayed_close_callback, but read_until_close takes
# a separate code path so test it separately.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
client.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"1234")
finally:
server.close()
client.close()
def test_streaming_read_until_close_after_close(self):
# Same as the preceding test but with a streaming_callback.
# All data should go through the streaming callback,
# and the final read callback just gets an empty string.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
streaming_data = []
client.read_until_close(self.stop,
streaming_callback=streaming_data.append)
data = self.wait()
self.assertEqual(b'', data)
self.assertEqual(b''.join(streaming_data), b"1234")
finally:
server.close()
client.close()
def test_large_read_until(self):
# Performance test: read_until used to have a quadratic component
# so a read_until of 4MB would take 8 seconds; now it takes 0.25
# seconds.
server, client = self.make_iostream_pair()
try:
# This test fails on pypy with ssl. I think it's because
            # pypy's gc moves objects, breaking the
# "frozen write buffer" assumption.
if (isinstance(server, SSLIOStream) and
platform.python_implementation() == 'PyPy'):
raise unittest.SkipTest(
"pypy gc causes problems with openssl")
NUM_KB = 4096
for i in range(NUM_KB):
client.write(b"A" * 1024)
client.write(b"\r\n")
server.read_until(b"\r\n", self.stop)
data = self.wait()
self.assertEqual(len(data), NUM_KB * 1024 + 2)
finally:
server.close()
client.close()
def test_close_callback_with_pending_read(self):
# Regression test for a bug that was introduced in 2.3
# where the IOStream._close_callback would never be called
# if there were pending reads.
OK = b"OK\r\n"
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(OK)
client.read_until(b"\r\n", self.stop)
res = self.wait()
self.assertEqual(res, OK)
server.close()
client.read_until(b"\r\n", lambda x: x)
# If _close_callback (self.stop) is not called,
# an AssertionError: Async operation timed out after 5 seconds
# will be raised.
res = self.wait()
self.assertTrue(res is None)
finally:
server.close()
client.close()
@skipIfNonUnix
def test_inline_read_error(self):
# An error on an inline read is raised without logging (on the
# assumption that it will eventually be noticed or logged further
# up the stack).
#
# This test is posix-only because windows os.close() doesn't work
# on socket FDs, but we can't close the socket object normally
# because we won't get the error we want if the socket knows
# it's closed.
server, client = self.make_iostream_pair()
try:
os.close(server.socket.fileno())
with self.assertRaises(socket.error):
server.read_bytes(1, lambda data: None)
finally:
server.close()
client.close()
def test_async_read_error_logging(self):
# Socket errors on asynchronous reads should be logged (but only
# once).
server, client = self.make_iostream_pair()
server.set_close_callback(self.stop)
try:
            # Start a read that will be fulfilled asynchronously.
server.read_bytes(1, lambda data: None)
client.write(b'a')
# Stub out read_from_fd to make it fail.
def fake_read_from_fd():
os.close(server.socket.fileno())
server.__class__.read_from_fd(server)
server.read_from_fd = fake_read_from_fd
# This log message is from _handle_read (not read_from_fd).
with ExpectLog(gen_log, "error on read"):
self.wait()
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
return IOStream(socket.socket(), io_loop=self.io_loop)
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
return SSLIOStream(socket.socket(), io_loop=self.io_loop)
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
return IOStream(connection, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return IOStream(connection, **kwargs)
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
ssl_options = dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
connection = ssl.wrap_socket(connection,
server_side=True,
do_handshake_on_connect=False,
**ssl_options)
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
# This will run some tests that are basically redundant but it's the
# simplest way to make sure that it works to pass an SSLContext
# instead of an ssl_options dict to the SSLIOStream constructor.
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.load_cert_chain(
os.path.join(os.path.dirname(__file__), 'test.crt'),
os.path.join(os.path.dirname(__file__), 'test.key'))
connection = ssl_wrap_socket(connection, context,
server_side=True,
do_handshake_on_connect=False)
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
return SSLIOStream(connection, io_loop=self.io_loop,
ssl_options=context, **kwargs)
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):
r, w = os.pipe()
rs = PipeIOStream(r, io_loop=self.io_loop)
ws = PipeIOStream(w, io_loop=self.io_loop)
ws.write(b"hel")
ws.write(b"lo world")
rs.read_until(b' ', callback=self.stop)
data = self.wait()
self.assertEqual(data, b"hello ")
rs.read_bytes(3, self.stop)
data = self.wait()
self.assertEqual(data, b"wor")
ws.close()
rs.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"ld")
rs.close()

View file

@@ -0,0 +1,59 @@
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import os
import tornado.locale
from tornado.escape import utf8
from tornado.test.util import unittest
from tornado.util import u, unicode_type
class TranslationLoaderTest(unittest.TestCase):
# TODO: less hacky way to get isolated tests
SAVE_VARS = ['_translations', '_supported_locales', '_use_gettext']
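    # setUp/tearDown snapshot and restore this module-level state so the
    # CSV and gettext loaders don't leak configuration between tests.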
def clear_locale_cache(self):
if hasattr(tornado.locale.Locale, '_cache'):
del tornado.locale.Locale._cache
def setUp(self):
self.saved = {}
for var in TranslationLoaderTest.SAVE_VARS:
self.saved[var] = getattr(tornado.locale, var)
self.clear_locale_cache()
def tearDown(self):
for k, v in self.saved.items():
setattr(tornado.locale, k, v)
self.clear_locale_cache()
def test_csv(self):
tornado.locale.load_translations(
os.path.join(os.path.dirname(__file__), 'csv_translations'))
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
def test_gettext(self):
tornado.locale.load_gettext_translations(
os.path.join(os.path.dirname(__file__), 'gettext_translations'),
"tornado_test")
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
self.assertTrue(isinstance(name, unicode_type))
self.assertEqual(name, u('Espa\u00f1ol'))
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
class EnglishTest(unittest.TestCase):
def test_format_date(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')

View file

@@ -0,0 +1,159 @@
#!/usr/bin/env python
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import glob
import logging
import os
import re
import tempfile
import warnings
from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import u, bytes_type, basestring_type
@contextlib.contextmanager
def ignore_bytes_warning():
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=BytesWarning)
yield
class LogFormatterTest(unittest.TestCase):
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
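    # \x01 and \x02 in the pattern match the fake color/normal escape
    # sequences patched into the formatter in setUp below.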
def setUp(self):
self.formatter = LogFormatter(color=False)
# Fake color support. We can't guarantee anything about the $TERM
# variable when the tests are run, so just patch in some values
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
self.formatter._colors = {
logging.ERROR: u("\u0001"),
}
self.formatter._normal = u("\u0002")
self.formatter._color = True
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger('LogFormatterTest')
self.logger.propagate = False
self.tempdir = tempfile.mkdtemp()
self.filename = os.path.join(self.tempdir, 'log.out')
self.handler = self.make_handler(self.filename)
self.handler.setFormatter(self.formatter)
self.logger.addHandler(self.handler)
def tearDown(self):
self.handler.close()
os.unlink(self.filename)
os.rmdir(self.tempdir)
def make_handler(self, filename):
# Base case: default setup without explicit encoding.
# In python 2, supports arbitrary byte strings and unicode objects
# that contain only ascii. In python 3, supports ascii-only unicode
# strings (but byte strings will be repr'd automatically).
return logging.FileHandler(filename)
def get_output(self):
with open(self.filename, "rb") as f:
line = f.read().strip()
m = LogFormatterTest.LINE_RE.match(line)
if m:
return m.group(1)
else:
raise Exception("output didn't match regex: %r" % line)
def test_basic_logging(self):
self.logger.error("foo")
self.assertEqual(self.get_output(), b"foo")
def test_bytes_logging(self):
with ignore_bytes_warning():
# This will be "\xe9" on python 2 or "b'\xe9'" on python 3
self.logger.error(b"\xe9")
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes_type, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
else:
# on python 3, byte strings always get repr'd even if
# they're ascii-only, so this degenerates into another
# copy of test_bytes_logging.
self.assertEqual(self.get_output(), utf8(repr(utf8(u("\u00e9")))))
def test_bytes_exception_logging(self):
try:
raise Exception(b'\xe9')
except Exception:
self.logger.exception('caught exception')
# This will be "Exception: \xe9" on python 2 or
# "Exception: b'\xe9'" on python 3.
output = self.get_output()
self.assertRegexpMatches(output, br'Exception.*\\xe9')
# The traceback contains newlines, which should not have been escaped.
self.assertNotIn(br'\n', output)
class UnicodeLogFormatterTest(LogFormatterTest):
def make_handler(self, filename):
# Adding an explicit encoding configuration allows non-ascii unicode
# strings in both python 2 and 3, without changing the behavior
# for byte strings.
return logging.FileHandler(filename, encoding="utf8")
def test_unicode_logging(self):
self.logger.error(u("\u00e9"))
self.assertEqual(self.get_output(), utf8(u("\u00e9")))
class EnablePrettyLoggingTest(unittest.TestCase):
def setUp(self):
super(EnablePrettyLoggingTest, self).setUp()
self.options = OptionParser()
define_logging_options(self.options)
self.logger = logging.Logger('tornado.test.log_test.EnablePrettyLoggingTest')
self.logger.propagate = False
def test_log_file(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + '/test_log'
enable_pretty_logging(options=self.options, logger=self.logger)
self.assertEqual(1, len(self.logger.handlers))
self.logger.error('hello')
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + '/test_log*')
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(f.read(), r'^\[E [^]]*\] hello$')
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)

View file

@@ -0,0 +1,84 @@
from __future__ import absolute_import, division, print_function, with_statement
import socket
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
try:
from concurrent import futures
except ImportError:
futures = None
try:
import pycares
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted
except ImportError:
twisted = None
else:
from tornado.platform.twisted import TwistedResolver
class _ResolverTestMixin(object):
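    # Each resolver implementation is exercised both callback-style
    # (test_localhost) and Future-style via gen_test (test_future_interface).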
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
@gen_test
def test_future_interface(self):
addrinfo = yield self.resolver.resolve('localhost', 80,
socket.AF_UNSPEC)
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(ThreadedResolverTest, self).setUp()
self.resolver = ThreadedResolver(io_loop=self.io_loop)
def tearDown(self):
self.resolver.close()
super(ThreadedResolverTest, self).tearDown()
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(CaresResolverTest, self).setUp()
self.resolver = CaresResolver(io_loop=self.io_loop)
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop)
class IsValidIPTest(unittest.TestCase):
def test_is_valid_ip(self):
self.assertTrue(is_valid_ip('127.0.0.1'))
self.assertTrue(is_valid_ip('4.4.4.4'))
self.assertTrue(is_valid_ip('::1'))
self.assertTrue(is_valid_ip('2620:0:1cfe:face:b00c::3'))
self.assertTrue(not is_valid_ip('www.google.com'))
self.assertTrue(not is_valid_ip('localhost'))
self.assertTrue(not is_valid_ip('4.4.4.4<'))
self.assertTrue(not is_valid_ip(' 127.0.0.1'))

View file

@@ -0,0 +1,2 @@
port=443
port=443

View file

@@ -0,0 +1,220 @@
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import os
import sys
from tornado.options import OptionParser, Error
from tornado.util import basestring_type
from tornado.test.util import unittest
try:
from cStringIO import StringIO # python 2
except ImportError:
from io import StringIO # python 3
try:
from unittest import mock # python 3.3
except ImportError:
try:
import mock # third-party mock package
except ImportError:
mock = None
class OptionsTest(unittest.TestCase):
def test_parse_command_line(self):
options = OptionParser()
options.define("port", default=80)
options.parse_command_line(["main.py", "--port=443"])
self.assertEqual(options.port, 443)
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
        self.assertEqual(options.port, 443)
def test_parse_callbacks(self):
options = OptionParser()
self.called = False
def callback():
self.called = True
options.add_parse_callback(callback)
# non-final parse doesn't run callbacks
options.parse_command_line(["main.py"], final=False)
self.assertFalse(self.called)
# final parse does
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
# callbacks can be run more than once on the same options
# object if there are multiple final parses
self.called = False
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
def test_help(self):
options = OptionParser()
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(SystemExit):
options.parse_command_line(["main.py", "--help"])
usage = sys.stderr.getvalue()
finally:
sys.stderr = orig_stderr
self.assertIn("Usage:", usage)
def test_subcommand(self):
base_options = OptionParser()
base_options.define("verbose", default=False)
sub_options = OptionParser()
sub_options.define("foo", type=str)
rest = base_options.parse_command_line(
["main.py", "--verbose", "subcommand", "--foo=bar"])
self.assertEqual(rest, ["subcommand", "--foo=bar"])
self.assertTrue(base_options.verbose)
rest2 = sub_options.parse_command_line(rest)
self.assertEqual(rest2, [])
self.assertEqual(sub_options.foo, "bar")
# the two option sets are distinct
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(Error):
sub_options.parse_command_line(["subcommand", "--verbose"])
finally:
sys.stderr = orig_stderr
def test_setattr(self):
options = OptionParser()
options.define('foo', default=1, type=int)
options.foo = 2
self.assertEqual(options.foo, 2)
def test_setattr_type_check(self):
# setattr requires that options be the right type and doesn't
# parse from string formats.
options = OptionParser()
options.define('foo', default=1, type=int)
with self.assertRaises(Error):
options.foo = '2'
def test_setattr_with_callback(self):
values = []
options = OptionParser()
options.define('foo', default=1, type=int, callback=values.append)
options.foo = 2
self.assertEqual(values, [2])
def _sample_options(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', default=2)
return options
def test_iter(self):
options = self._sample_options()
# OptionParsers always define 'help'.
self.assertEqual(set(['a', 'b', 'help']), set(iter(options)))
def test_getitem(self):
options = self._sample_options()
self.assertEqual(1, options['a'])
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
expected = [('a', 1), ('b', 2), ('help', options.help)]
actual = sorted(options.items())
self.assertEqual(expected, actual)
def test_as_dict(self):
options = self._sample_options()
expected = {'a': 1, 'b': 2, 'help': options.help}
self.assertEqual(expected, options.as_dict())
def test_group_dict(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', group='b_group', default=2)
frame = sys._getframe(0)
this_file = frame.f_code.co_filename
self.assertEqual(set(['b_group', '', this_file]), options.groups())
b_group_dict = options.group_dict('b_group')
self.assertEqual({'b': 2}, b_group_dict)
self.assertEqual({}, options.group_dict('nonexistent'))
@unittest.skipIf(mock is None, 'mock package not present')
def test_mock_patch(self):
# ensure that our setattr hooks don't interfere with mock.patch
options = OptionParser()
options.define('foo', default=1)
options.parse_command_line(['main.py', '--foo=2'])
self.assertEqual(options.foo, 2)
with mock.patch.object(options.mockable(), 'foo', 3):
self.assertEqual(options.foo, 3)
self.assertEqual(options.foo, 2)
# Try nested patches mixed with explicit sets
with mock.patch.object(options.mockable(), 'foo', 4):
self.assertEqual(options.foo, 4)
options.foo = 5
self.assertEqual(options.foo, 5)
with mock.patch.object(options.mockable(), 'foo', 6):
self.assertEqual(options.foo, 6)
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def test_types(self):
options = OptionParser()
options.define('str', type=str)
options.define('basestring', type=basestring_type)
options.define('int', type=int)
options.define('float', type=float)
options.define('datetime', type=datetime.datetime)
options.define('timedelta', type=datetime.timedelta)
options.parse_command_line(['main.py',
'--str=asdf',
'--basestring=qwer',
'--int=42',
'--float=1.5',
'--datetime=2013-04-28 05:16',
'--timedelta=45s'])
self.assertEqual(options.str, 'asdf')
self.assertEqual(options.basestring, 'qwer')
self.assertEqual(options.int, 42)
self.assertEqual(options.float, 1.5)
self.assertEqual(options.datetime,
datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
def test_multiple_string(self):
options = OptionParser()
options.define('foo', type=str, multiple=True)
options.parse_command_line(['main.py', '--foo=a,b,c'])
self.assertEqual(options.foo, ['a', 'b', 'c'])
def test_multiple_int(self):
options = OptionParser()
options.define('foo', type=int, multiple=True)
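        # With multiple=True, integer options also accept 'a:b' as an
        # inclusive range, so '1,3,5:7' should expand to [1, 3, 5, 6, 7].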
options.parse_command_line(['main.py', '--foo=1,3,5:7'])
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
def test_error_redefine(self):
options = OptionParser()
options.define('foo')
with self.assertRaises(Error) as cm:
options.define('foo')
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.*already defined')

View file

@@ -0,0 +1,204 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import logging
import os
import signal
import subprocess
import sys
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
def skip_if_twisted():
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@skipIfNonUnix
class ProcessTest(unittest.TestCase):
def get_app(self):
class ProcessHandler(RequestHandler):
def get(self):
if self.get_argument("exit", None):
# must use os._exit instead of sys.exit so unittest's
# exception handler doesn't catch it
os._exit(int(self.get_argument("exit")))
if self.get_argument("signal", None):
os.kill(os.getpid(),
int(self.get_argument("signal")))
self.write(str(os.getpid()))
return Application([("/", ProcessHandler)])
def tearDown(self):
if task_id() is not None:
# We're in a child process, and probably got to this point
# via an uncaught exception. If we return now, both
# processes will continue with the rest of the test suite.
# Exit now so the parent process will restart the child
# (since we don't have a clean way to signal failure to
# the parent that won't restart)
logging.error("aborting child process from tearDown")
logging.shutdown()
os._exit(1)
# In the surviving process, clear the alarm we set earlier
signal.alarm(0)
super(ProcessTest, self).tearDown()
def test_multi_process(self):
# This test can't work on twisted because we use the global reactor
# and have no way to get it back into a sane state after the fork.
skip_if_twisted()
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
self.assertFalse(IOLoop.initialized())
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
id = fork_processes(3, max_restarts=3)
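                # In each child, fork_processes returns the task id; in the
                # parent it returns only by raising SystemExit once all
                # children have exited (handled below).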
self.assertTrue(id is not None)
signal.alarm(5) # child processes
except SystemExit as e:
# if we exit cleanly from fork_processes, all the child processes
# finished with status 0
self.assertEqual(e.code, 0)
self.assertTrue(task_id() is None)
sock.close()
return
try:
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.instance().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
# Always use SimpleAsyncHTTPClient here; the curl
# version appears to get confused sometimes if the
# connection gets closed before it's had a chance to
# switch from writing mode to reading mode.
client = HTTPClient(SimpleAsyncHTTPClient)
def fetch(url, fail_ok=False):
try:
return client.fetch(get_url(url))
except HTTPError as e:
if not (fail_ok and e.code == 599):
raise
# Make two processes exit abnormally
fetch("/?exit=2", fail_ok=True)
fetch("/?exit=3", fail_ok=True)
# They've been restarted, so a new fetch will work
int(fetch("/").body)
# Now the same with signals
# Disabled because on the mac a process dying with a signal
# can trigger an "Application exited abnormally; send error
# report to Apple?" prompt.
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
# int(fetch("/").body)
# Now kill them normally so they won't be restarted
fetch("/?exit=0", fail_ok=True)
                # One process left; watch its pid change
pid = int(fetch("/").body)
fetch("/?exit=4", fail_ok=True)
pid2 = int(fetch("/").body)
self.assertNotEqual(pid, pid2)
# Kill the last one so we shut down cleanly
fetch("/?exit=0", fail_ok=True)
os._exit(0)
except Exception:
logging.error("exception in child process %d", id, exc_info=True)
raise
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
def test_subprocess(self):
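        # Drives an interactive python (-u -i) over pipes: wait for the
        # '>>> ' prompt, exchange one line, then exit and read to EOF.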
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
subproc.stdin.write(b"print('hello')\n")
subproc.stdout.read_until(b'\n', self.stop)
data = self.wait()
self.assertEqual(data, b"hello\n")
subproc.stdout.read_until(b">>> ", self.stop)
self.wait()
subproc.stdin.write(b"raise SystemExit\n")
subproc.stdout.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"")
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stdout.read_until(b'>>> ', self.stop)
self.wait()
subproc.stdin.close()
subproc.stdout.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"\n")
def test_stderr(self):
subproc = Subprocess([sys.executable, '-u', '-c',
r"import sys; sys.stderr.write('hello\n')"],
stderr=Subprocess.STREAM,
io_loop=self.io_loop)
self.addCleanup(lambda: os.kill(subproc.pid, signal.SIGTERM))
subproc.stderr.read_until(b'\n', self.stop)
data = self.wait()
self.assertEqual(data, b'hello\n')
def test_sigchild(self):
# Twisted's SIGCHLD handler and Subprocess's conflict with each other.
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'],
io_loop=self.io_loop)
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c',
'import time; time.sleep(30)'],
io_loop=self.io_loop)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
ret = self.wait()
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)

View file

@@ -0,0 +1,123 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
import gc
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
# 'reduce' is a builtin on python 2 but lives in functools on python 3;
# the --debug_gc option callback below needs it on both.
from functools import reduce
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
from tornado.test.util import unittest
TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.auth_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
'tornado.test.gen_test',
'tornado.test.httpclient_test',
'tornado.test.httpserver_test',
'tornado.test.httputil_test',
'tornado.test.import_test',
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
'tornado.test.locale_test',
'tornado.test.netutil_test',
'tornado.test.log_test',
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
'tornado.test.util_test',
'tornado.test.web_test',
'tornado.test.websocket_test',
'tornado.test.wsgi_test',
]
def all():
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
class TornadoTextTestRunner(unittest.TextTestRunner):
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(textwrap.fill(
"Some tests were skipped because: %s" %
", ".join(sorted(skip_reasons))))
self.stream.write("\n")
return result
if __name__ == '__main__':
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
import warnings
# Be strict about most warnings. This also turns on warnings that are
# ignored by default, including DeprecationWarnings and
# python 3.2's ResourceWarnings.
warnings.filterwarnings("error")
# setuptools sometimes gives ImportWarnings about things that are on
# sys.path even if they're not being used.
warnings.filterwarnings("ignore", category=ImportWarning)
# Tornado generally shouldn't use anything deprecated, but some of
# our dependencies do (last match wins).
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning,
module=r"tornado\..*")
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=AsyncHTTPClient.configure)
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
callback=Resolver.configure)
define('debug_gc', type=str, multiple=True,
help="A comma-separated list of gc module debug constants, "
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
callback=lambda values: gc.set_debug(
reduce(operator.or_, (getattr(gc, v) for v in values))))
define('locale', type=str, default=None,
callback=lambda x: locale.setlocale(locale.LC_ALL, x))
def configure_ioloop():
kwargs = {}
if options.ioloop_time_monotonic:
from tornado.platform.auto import monotonic_time
if monotonic_time is None:
raise RuntimeError("monotonic clock not found")
kwargs['time_func'] = monotonic_time
if options.ioloop or kwargs:
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
# HACK: unittest.main will make its own changes to the warning
# configuration, which may conflict with the settings above
# or command-line flags like -bb. Passing warnings=False
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
tornado.testing.main(**kwargs)

View file

@@ -0,0 +1,398 @@
from __future__ import absolute_import, division, print_function, with_statement
import collections
from contextlib import closing
import errno
import gzip
import logging
import os
import re
import socket
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.web import RequestHandler, Application, asynchronous, url
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
force_instance=True)
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
return client
class TriggerHandler(RequestHandler):
def initialize(self, queue, wake_callback):
self.queue = queue
self.wake_callback = wake_callback
@asynchronous
def get(self):
logging.debug("queuing trigger")
self.queue.append(self.finish)
if self.get_argument("wake", "true") == "true":
self.wake_callback()
class HangHandler(RequestHandler):
@asynchronous
def get(self):
pass
class ContentLengthHandler(RequestHandler):
def get(self):
self.set_header("Content-Length", self.get_argument("value"))
self.write("ok")
class HeadHandler(RequestHandler):
def head(self):
self.set_header("Content-Length", "7")
class OptionsHandler(RequestHandler):
def options(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.write("ok")
class NoContentHandler(RequestHandler):
def get(self):
if self.get_argument("error", None):
self.set_header("Content-Length", "7")
self.set_status(204)
class SeeOtherPostHandler(RequestHandler):
def post(self):
redirect_code = int(self.request.body)
assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
self.set_header("Location", "/see_other_get")
self.set_status(redirect_code)
class SeeOtherGetHandler(RequestHandler):
def get(self):
if self.request.body:
raise Exception("unexpected body %r" % self.request.body)
self.write("ok")
class HostEchoHandler(RequestHandler):
def get(self):
self.write(self.request.headers["Host"])
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque()
return Application([
url("/trigger", TriggerHandler, dict(queue=self.triggers,
wake_callback=self.stop)),
url("/chunk", ChunkHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/hang", HangHandler),
url("/hello", HelloWorldHandler),
url("/content_length", ContentLengthHandler),
url("/head", HeadHandler),
url("/options", OptionsHandler),
url("/no_content", NoContentHandler),
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
], gzip=True)
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
SimpleAsyncHTTPClient(self.io_loop))
# unless force_instance is used
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(self.io_loop,
force_instance=True))
# different IOLoops use different objects
io_loop2 = IOLoop()
self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
SimpleAsyncHTTPClient(io_loop2))
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
self.assertEqual(client.max_clients, 2)
seen = []
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
client.fetch(self.get_url("/trigger"),
lambda response, i=i: (seen.append(i), self.stop()))
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
# Finish the first two requests and let the next two through
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: (len(self.triggers) == 2 and
len(seen) == 2))
self.assertEqual(set(seen), set([0, 1]))
self.assertEqual(len(client.queue), 0)
# Finish all the pending requests
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: len(seen) == 4)
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
client.fetch(self.get_url('/countdown/3'), self.stop,
max_redirects=3)
response = self.wait()
response.rethrow()
def test_default_certificates_exist(self):
open(_DEFAULT_CA_CERTS).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
response = self.fetch("/chunk", use_gzip=False,
headers={"Accept-Encoding": "gzip"})
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
def test_max_redirects(self):
response = self.fetch("/countdown/5", max_redirects=3)
self.assertEqual(302, response.code)
# We requested 5, followed three redirects for 4, 3, 2, then the last
# unfollowed redirect is to 1.
self.assertTrue(response.request.url.endswith("/countdown/5"))
self.assertTrue(response.effective_url.endswith("/countdown/2"))
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
def test_header_reuse(self):
# Apps may reuse a headers object if they are only passing in constant
# headers like user-agent. The header object should not be modified.
headers = HTTPHeaders({'User-Agent': 'Foo'})
self.fetch("/hello", headers=headers)
self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])
def test_see_other_redirect(self):
for code in (302, 303):
response = self.fetch("/see_other_post", method="POST", body="%d" % code)
self.assertEqual(200, response.code)
self.assertTrue(response.request.url.endswith("/see_other_post"))
self.assertTrue(response.effective_url.endswith("/see_other_get"))
# request is the original request, is a POST still
self.assertEqual("POST", response.request.method)
@skipOnTravis
def test_request_timeout(self):
response = self.fetch('/trigger?wake=false', request_timeout=0.1)
self.assertEqual(response.code, 599)
self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
def test_ipv6(self):
try:
self.http_server.listen(self.get_http_port(), address='::1')
except socket.gaierror as e:
if e.args[0] == socket.EAI_ADDRFAMILY:
# python supports ipv6, but it's not configured on the network
# interface, so skip this test.
return
raise
url = self.get_url("/hello").replace("localhost", "[::1]")
# ipv6 is currently disabled by default and must be explicitly requested
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(response.code, 599)
self.http_client.fetch(url, self.stop, allow_ipv6=True)
response = self.wait()
self.assertEqual(response.body, b"Hello world!")
def test_multiple_content_length_accepted(self):
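        # Duplicate Content-Length headers are tolerated as long as the
        # values agree; conflicting values should fail the request (599).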
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,4")
self.assertEqual(response.code, 599)
response = self.fetch("/content_length?value=2,%202,3")
self.assertEqual(response.code, 599)
def test_head_request(self):
response = self.fetch("/head", method="HEAD")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "7")
self.assertFalse(response.body)
def test_options_request(self):
response = self.fetch("/options", method="OPTIONS")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "2")
self.assertEqual(response.headers["access-control-allow-origin"], "*")
self.assertEqual(response.body, b"ok")
def test_no_content(self):
response = self.fetch("/no_content")
self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway.
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
def test_host_header(self):
host_re = re.compile(b"^localhost:[0-9]+$")
response = self.fetch("/host_echo")
self.assertTrue(host_re.match(response.body))
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
server_socket, port = bind_unused_port()
server_socket.close()
with ExpectLog(gen_log, ".*", required=False):
self.http_client.fetch("http://localhost:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here
self.assertTrue(str(errno.ECONNREFUSED) in str(response.error),
response.error)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(response.error),
response.error)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
super(SimpleHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
**kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
def setUp(self):
super(SimpleHTTPSClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
defaults=dict(validate_cert=False),
**kwargs)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
super(CreateAsyncHTTPClientTestCase, self).setUp()
self.saved = AsyncHTTPClient._save_configuration()
def tearDown(self):
AsyncHTTPClient._restore_configuration(self.saved)
super(CreateAsyncHTTPClientTestCase, self).tearDown()
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
self.assertEqual(client.max_clients, 10)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11)
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(
self.io_loop, force_instance=True)) as client:
self.assertEqual(client.max_clients, 12)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13)
with closing(AsyncHTTPClient(
self.io_loop, max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14)
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
self.respond_200)
def respond_200(self):
self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA",
self.request.connection.stream.close)
def get_app(self):
# Not a full Application, but works as an HTTPServer callback
return self.respond_100
def test_100_continue(self):
res = self.fetch('/')
self.assertEqual(res.body, b'A')
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
self.io_loop,
hostname_mapping={
'www.example.com': '127.0.0.1',
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
})
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_hostname_mapping(self):
self.http_client.fetch(
'http://www.example.com:%d/hello' % self.get_http_port(), self.stop)
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
def test_port_mapping(self):
self.http_client.fetch('http://foo.example.com:8000/hello', self.stop)
response = self.wait()
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
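# A minimal usage sketch, not exercised by this suite: hostname_mapping lets
# SimpleAsyncHTTPClient resolve selected hostnames (or host/port pairs) to
# fixed addresses. The hostname and no-op callback below are hypothetical.
def _hostname_mapping_sketch(io_loop):
    client = SimpleAsyncHTTPClient(
        io_loop, force_instance=True,
        hostname_mapping={'api.example.com': '127.0.0.1'})
    # The TCP connection goes to 127.0.0.1; the URL and Host header
    # still refer to api.example.com.
    client.fetch('http://api.example.com/ping', lambda response: None)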

View file

@ -0,0 +1,281 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen
from tornado.log import app_log
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
ExceptionStackContext, run_with_stack_context, _state)
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.web import asynchronous, Application, RequestHandler
import contextlib
import functools
import logging
class TestRequestHandler(RequestHandler):
def __init__(self, app, request, io_loop):
super(TestRequestHandler, self).__init__(app, request)
self.io_loop = io_loop
@asynchronous
def get(self):
logging.debug('in get()')
# call self.part2 without a self.async_callback wrapper. Its
# exception should still get thrown
self.io_loop.add_callback(self.part2)
def part2(self):
logging.debug('in part2()')
# Go through a third layer to make sure that contexts once restored
# are again passed on to future callbacks
self.io_loop.add_callback(self.part3)
def part3(self):
logging.debug('in part3()')
raise Exception('test exception')
def get_error_html(self, status_code, **kwargs):
if 'exception' in kwargs and str(kwargs['exception']) == 'test exception':
return 'got expected exception'
else:
return 'unexpected failure'
class HTTPStackContextTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', TestRequestHandler,
dict(io_loop=self.io_loop))])
def test_stack_context(self):
with ExpectLog(app_log, "Uncaught exception GET /"):
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
self.assertEqual(self.response.code, 500)
self.assertTrue(b'got expected exception' in self.response.body)
def handle_response(self, response):
self.response = response
self.stop()
class StackContextTest(AsyncTestCase):
def setUp(self):
super(StackContextTest, self).setUp()
self.active_contexts = []
@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
# Simulates the effect of an asynchronous library that uses its own
# StackContext internally and then returns control to the application.
def test_exit_library_context(self):
def library_function(callback):
# capture the caller's context before introducing our own
callback = wrap(callback)
with StackContext(functools.partial(self.context, 'library')):
self.io_loop.add_callback(
functools.partial(library_inner_callback, callback))
def library_inner_callback(callback):
self.assertEqual(self.active_contexts[-2:],
['application', 'library'])
callback()
def final_callback():
# implementation detail: the full context stack at this point
# is ['application', 'library', 'application']. The 'library'
# context was not removed, but is no longer innermost so
# the application context takes precedence.
self.assertEqual(self.active_contexts[-1], 'application')
self.stop()
with StackContext(functools.partial(self.context, 'application')):
library_function(final_callback)
self.wait()
def test_deactivate(self):
deactivate_callbacks = []
def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)
def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)
def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)
def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)
def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_deactivate_order(self):
# Stack context deactivation has separate logic for deactivation at
# the head and tail of the stack, so make sure it works in any order.
def check_contexts():
# Make sure that the full-context array and the exception-context
# linked lists are consistent with each other.
full_contexts, chain = _state.contexts
exception_contexts = []
while chain is not None:
exception_contexts.append(chain)
chain = chain.old_contexts[1]
self.assertEqual(list(reversed(full_contexts)), exception_contexts)
return list(self.active_contexts)
def make_wrapped_function():
"""Wraps a function in three stack contexts, and returns
the function along with the deactivation functions.
"""
# Remove the test's stack context to make sure we can cover
# the case where the last context is deactivated.
with NullContext():
partial = functools.partial
with StackContext(partial(self.context, 'c0')) as c0:
with StackContext(partial(self.context, 'c1')) as c1:
with StackContext(partial(self.context, 'c2')) as c2:
return (wrap(check_contexts), [c0, c1, c2])
# First make sure the test mechanism works without any deactivations
func, deactivate_callbacks = make_wrapped_function()
self.assertEqual(func(), ['c0', 'c1', 'c2'])
# Deactivate the tail
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[0]()
self.assertEqual(func(), ['c1', 'c2'])
# Deactivate the middle
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[1]()
self.assertEqual(func(), ['c0', 'c2'])
# Deactivate the head
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[2]()
self.assertEqual(func(), ['c0', 'c1'])
def test_isolation_nonempty(self):
# f2 and f3 are a chain of operations started in context c1.
# f2 is incidentally run under context c2, but that context should
# not be passed along to f3.
def f1():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.assertIn('c1', self.active_contexts)
self.io_loop.add_callback(f3)
def f3():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_isolation_empty(self):
# Similar to test_isolation_nonempty, but here the f2/f3 chain
# is started without any context. Behavior should be equivalent
# to the nonempty case (although historically it was not)
def f1():
with NullContext():
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.io_loop.add_callback(f3)
def f3():
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_yield_in_with(self):
@gen.engine
def f():
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Task(self.io_loop.add_callback)
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
@gen_test
def test_yield_outside_with(self):
# This pattern avoids the problem in the previous test.
cb = yield gen.Callback('k1')
with StackContext(functools.partial(self.context, 'c1')):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
def test_yield_in_with_exception_stack_context(self):
# As above, but with ExceptionStackContext instead of StackContext.
@gen.engine
def f():
with ExceptionStackContext(lambda t, v, tb: False):
yield gen.Task(self.io_loop.add_callback)
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
@gen_test
def test_yield_outside_with_exception_stack_context(self):
cb = yield gen.Callback('k1')
with ExceptionStackContext(lambda t, v, tb: False):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
def test_run_with_stack_context(self):
@gen.coroutine
def f1():
self.assertEqual(self.active_contexts, ['c1'])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
f2)
self.assertEqual(self.active_contexts, ['c1'])
@gen.coroutine
def f2():
self.assertEqual(self.active_contexts, ['c1', 'c2'])
yield gen.Task(self.io_loop.add_callback)
self.assertEqual(self.active_contexts, ['c1', 'c2'])
self.assertEqual(self.active_contexts, [])
run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
f1)
self.assertEqual(self.active_contexts, [])
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1 @@
this is the index

View file

@ -0,0 +1,2 @@
User-agent: *
Disallow: /

View file

@ -0,0 +1,391 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import sys
import traceback
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
class TemplateTest(unittest.TestCase):
def test_simple(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name="Ben"),
b"Hello Ben!")
def test_bytes(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")),
b"Hello Ben!")
def test_expressions(self):
template = Template("2 + 2 = {{ 2 + 2 }}")
self.assertEqual(template.generate(), b"2 + 2 = 4")
def test_comment(self):
template = Template("Hello{# TODO i18n #} {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")),
b"Hello Ben!")
def test_include(self):
loader = DictLoader({
"index.html": '{% include "header.html" %}\nbody text',
"header.html": "header text",
})
self.assertEqual(loader.load("index.html").generate(),
b"header text\nbody text")
def test_extends(self):
loader = DictLoader({
"base.html": """\
<title>{% block title %}default title{% end %}</title>
<body>{% block body %}default body{% end %}</body>
""",
"page.html": """\
{% extends "base.html" %}
{% block title %}page title{% end %}
{% block body %}page body{% end %}
""",
})
self.assertEqual(loader.load("page.html").generate(),
b"<title>page title</title>\n<body>page body</body>\n")
def test_relative_load(self):
loader = DictLoader({
"a/1.html": "{% include '2.html' %}",
"a/2.html": "{% include '../b/3.html' %}",
"b/3.html": "ok",
})
self.assertEqual(loader.load("a/1.html").generate(),
b"ok")
def test_escaping(self):
self.assertRaises(ParseError, lambda: Template("{{"))
self.assertRaises(ParseError, lambda: Template("{%"))
self.assertEqual(Template("{{!").generate(), b"{{")
self.assertEqual(Template("{%!").generate(), b"{%")
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
b"expr {{jquery expr}}")
def test_unicode_template(self):
template = Template(utf8(u("\u00e9")))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
def test_unicode_literal_expression(self):
# Unicode literals should be usable in templates. Note that this
# test simulates unicode characters appearing directly in the
# template file (with utf8 encoding), i.e. \u escapes would not
# be used in the template file itself.
if str is unicode_type:
# python 3 needs a different version of this test since
# 2to3 doesn't run on template internals
template = Template(utf8(u('{{ "\u00e9" }}')))
else:
template = Template(utf8(u('{{ u"\u00e9" }}')))
self.assertEqual(template.generate(), utf8(u("\u00e9")))
def test_custom_namespace(self):
loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
self.assertEqual(loader.load("test.html").generate(), b"6")
def test_apply(self):
def upper(s):
return s.upper()
template = Template(utf8("{% apply upper %}foo{% end %}"))
self.assertEqual(template.generate(upper=upper), b"FOO")
def test_unicode_apply(self):
def upper(s):
return to_unicode(s).upper()
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
def test_bytes_apply(self):
def upper(s):
return utf8(to_unicode(s).upper())
template = Template(utf8(u("{% apply upper %}foo \u00e9{% end %}")))
self.assertEqual(template.generate(upper=upper), utf8(u("FOO \u00c9")))
def test_if(self):
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
self.assertEqual(template.generate(x=5), b"yes")
self.assertEqual(template.generate(x=3), b"no")
def test_if_empty_body(self):
template = Template(utf8("{% if True %}{% else %}{% end %}"))
self.assertEqual(template.generate(), b"")
def test_try(self):
template = Template(utf8("""{% try %}
try{% set y = 1/x %}
{% except %}-except
{% else %}-else
{% finally %}-finally
{% end %}"""))
self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
def test_comment_directive(self):
template = Template(utf8("{% comment blah blah %}foo"))
self.assertEqual(template.generate(), b"foo")
def test_break_continue(self):
template = Template(utf8("""\
{% for i in range(10) %}
{% if i == 2 %}
{% continue %}
{% end %}
{{ i }}
{% if i == 6 %}
{% break %}
{% end %}
{% end %}"""))
result = template.generate()
# remove extraneous whitespace
result = b''.join(result.split())
self.assertEqual(result, b"013456")
def test_break_outside_loop(self):
try:
Template(utf8("{% break %}"))
raise Exception("Did not get expected exception")
except ParseError:
pass
def test_break_in_apply(self):
# This test verifies current behavior, although of course it would
# be nice if apply didn't cause seemingly unrelated breakage
try:
Template(utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}"))
raise Exception("Did not get expected exception")
except ParseError:
pass
@unittest.skipIf(sys.version_info >= division.getMandatoryRelease(),
'no testable future imports')
def test_no_inherit_future(self):
# This file has from __future__ import division...
self.assertEqual(1 / 2, 0.5)
# ...but the template doesn't
template = Template('{{ 1 / 2 }}')
self.assertEqual(template.generate(), '0')
class StackTraceTest(unittest.TestCase):
def test_error_line_number_expression(self):
loader = DictLoader({"test.html": """one
two{{1/0}}
three
"""})
try:
loader.load("test.html").generate()
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_directive(self):
loader = DictLoader({"test.html": """one
two{%if 1/0%}
three{%end%}
"""})
try:
loader.load("test.html").generate()
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_module(self):
loader = DictLoader({
"base.html": "{% module Template('sub.html') %}",
"sub.html": "{{1/0}}",
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
try:
loader.load("base.html").generate()
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue('# base.html:1' in exc_stack)
self.assertTrue('# sub.html:1' in exc_stack)
def test_error_line_number_include(self):
loader = DictLoader({
"base.html": "{% include 'sub.html' %}",
"sub.html": "{{1/0}}",
})
try:
loader.load("base.html").generate()
except ZeroDivisionError:
self.assertTrue("# sub.html:1 (via base.html:1)" in
traceback.format_exc())
def test_error_line_number_extends_base_error(self):
loader = DictLoader({
"base.html": "{{1/0}}",
"sub.html": "{% extends 'base.html' %}",
})
try:
loader.load("sub.html").generate()
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
def test_error_line_number_extends_sub_error(self):
loader = DictLoader({
"base.html": "{% block 'block' %}{% end %}",
"sub.html": """
{% extends 'base.html' %}
{% block 'block' %}
{{1/0}}
{% end %}
"""})
try:
loader.load("sub.html").generate()
except ZeroDivisionError:
self.assertTrue("# sub.html:4 (via base.html:1)" in
traceback.format_exc())
def test_multi_includes(self):
loader = DictLoader({
"a.html": "{% include 'b.html' %}",
"b.html": "{% include 'c.html' %}",
"c.html": "{{1/0}}",
})
try:
loader.load("a.html").generate()
except ZeroDivisionError:
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
traceback.format_exc())
class AutoEscapeTest(unittest.TestCase):
def setUp(self):
self.templates = {
"escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
"unescaped.html": "{% autoescape None %}{{ name }}",
"default.html": "{{ name }}",
"include.html": """\
escaped: {% include 'escaped.html' %}
unescaped: {% include 'unescaped.html' %}
default: {% include 'default.html' %}
""",
"escaped_block.html": """\
{% autoescape xhtml_escape %}\
{% block name %}base: {{ name }}{% end %}""",
"unescaped_block.html": """\
{% autoescape None %}\
{% block name %}base: {{ name }}{% end %}""",
# Extend a base template with different autoescape policy,
# with and without overriding the base's blocks
"escaped_extends_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}""",
"escaped_overrides_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"unescaped_extends_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}""",
"unescaped_overrides_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"raw_expression.html": """\
{% autoescape xhtml_escape %}\
expr: {{ name }}
raw: {% raw name %}""",
}
def test_default_off(self):
loader = DictLoader(self.templates, autoescape=None)
name = "Bobby <table>s"
self.assertEqual(loader.load("escaped.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("unescaped.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("default.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby <table>s\n")
def test_default_on(self):
loader = DictLoader(self.templates, autoescape="xhtml_escape")
name = "Bobby <table>s"
self.assertEqual(loader.load("escaped.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("unescaped.html").generate(name=name),
b"Bobby <table>s")
self.assertEqual(loader.load("default.html").generate(name=name),
b"Bobby &lt;table&gt;s")
self.assertEqual(loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby &lt;table&gt;s\n")
def test_unextended_block(self):
loader = DictLoader(self.templates)
name = "<script>"
self.assertEqual(loader.load("escaped_block.html").generate(name=name),
b"base: &lt;script&gt;")
self.assertEqual(loader.load("unescaped_block.html").generate(name=name),
b"base: <script>")
def test_extended_block(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name="<script>")
self.assertEqual(render("escaped_extends_unescaped.html"),
b"base: <script>")
self.assertEqual(render("escaped_overrides_unescaped.html"),
b"extended: &lt;script&gt;")
self.assertEqual(render("unescaped_extends_escaped.html"),
b"base: &lt;script&gt;")
self.assertEqual(render("unescaped_overrides_escaped.html"),
b"extended: <script>")
def test_raw_expression(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name='<>&"')
self.assertEqual(render("raw_expression.html"),
b"expr: &lt;&gt;&amp;&quot;\n"
b"raw: <>&\"")
def test_custom_escape(self):
loader = DictLoader({"foo.py":
"{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes_type)
return repr(native_str(s))
def render(template, name):
return loader.load(template).generate(py_escape=py_escape,
name=name)
self.assertEqual(render("foo.py", "<html>"),
b"s = '<html>'\n")
self.assertEqual(render("foo.py", "';sys.exit()"),
b"""s = "';sys.exit()"\n""")
self.assertEqual(render("foo.py", ["not a string"]),
b"""s = "['not a string']"\n""")
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):
self.loader = Loader(os.path.join(os.path.dirname(__file__), "templates"))
def test_utf8_in_file(self):
tmpl = self.loader.load("utf8.html")
result = tmpl.generate()
self.assertEqual(to_unicode(result).strip(), u("H\u00e9llo"))

View file

@ -0,0 +1 @@
Héllo

View file

@ -0,0 +1,15 @@
-----BEGIN CERTIFICATE-----
MIICSDCCAbGgAwIBAgIJAN1oTowzMbkzMA0GCSqGSIb3DQEBBQUAMD0xCzAJBgNV
BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRkwFwYDVQQKDBBUb3JuYWRvIFdl
YiBUZXN0MB4XDTEwMDgyNTE4MjQ0NFoXDTIwMDgyMjE4MjQ0NFowPTELMAkGA1UE
BhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoMEFRvcm5hZG8gV2Vi
IFRlc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALirW3mX4jbdFse2aZwW
zszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcLtKqv
jnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+PzU39
UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAGjUDBOMB0GA1UdDgQWBBRhJjMBYrzddCFr
/0vvPyHMeqgo0TAfBgNVHSMEGDAWgBRhJjMBYrzddCFr/0vvPyHMeqgo0TAMBgNV
HRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAGP6GaxSfb21bikcqaK3ZKCC1sRJ
tiCuvJZbBUFUCAzl05dYUfJZim/oWK+GqyUkUB8ciYivUNnn9OtS7DnlTgT2ws2e
lNgn5cuFXoAGcHXzVlHG3yoywYBf3y0Dn20uzrlLXUWJAzoSLOt2LTaXvwlgm7hF
W1q8SQ6UBshRw2X0
-----END CERTIFICATE-----

View file

@ -0,0 +1,16 @@
-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALirW3mX4jbdFse2
aZwWzszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcL
tKqvjnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+P
zU39UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAECgYEAiygNaWYrf95AcUQi9w00zpUr
nj9fNvCwxr2kVbRMvd2balS/CC4EmXPCXdVcZ3B7dBVjYzSIJV0Fh/iZLtnVysD9
fcNMZ+Cz71b/T0ItsNYOsJk0qUVyP52uqsqkNppIPJsD19C+ZeMLZj6iEiylZyl8
2U16c/kVIjER63mUEGkCQQDayQOTGPJrKHqPAkUqzeJkfvHH2yCf+cySU+w6ezyr
j9yxcq8aZoLusCebDVT+kz7RqnD5JePFvB38cMuepYBLAkEA2BTFdZx30f4moPNv
JlXlPNJMUTUzsXG7n4vNc+18O5ous0NGQII8jZWrIcTrP8wiP9fF3JwUsKrJhcBn
xRs3hQJBAIDUgz1YIE+HW3vgi1gkOh6RPdBAsVpiXtr/fggFz3j60qrO7FswaAMj
SX8c/6KUlBYkNjgP3qruFf4zcUNvEzcCQQCaioCPFVE9ByBpjLG6IUTKsz2R9xL5
nfYqrbpLZ1aq6iLsYvkjugHE4X57sHLwNfdo4dHJbnf9wqhO2MVe25BhAkBdKYpY
7OKc/2mmMbJDhVBgoixz/muN/5VjdfbvVY48naZkJF1p1tmogqPC5F1jPCS4rM+S
FfPJIHRNEn2oktw5
-----END PRIVATE KEY-----

View file

@ -0,0 +1,159 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen, ioloop
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
import contextlib
import os
@contextlib.contextmanager
def set_environ(name, value):
old_value = os.environ.get(name)
os.environ[name] = value
try:
yield
finally:
if old_value is None:
del os.environ[name]
else:
os.environ[name] = old_value
class AsyncTestCaseTest(AsyncTestCase):
def test_exception_in_callback(self):
self.io_loop.add_callback(lambda: 1 / 0)
try:
self.wait()
self.fail("did not get expected exception")
except ZeroDivisionError:
pass
def test_wait_timeout(self):
time = self.io_loop.time
# Accept default 5-second timeout, no error
self.io_loop.add_timeout(time() + 0.01, self.stop)
self.wait()
# Timeout passed to wait()
self.io_loop.add_timeout(time() + 1, self.stop)
with self.assertRaises(self.failureException):
self.wait(timeout=0.01)
# Timeout set with environment variable
self.io_loop.add_timeout(time() + 1, self.stop)
with set_environ('ASYNC_TEST_TIMEOUT', '0.01'):
with self.assertRaises(self.failureException):
self.wait()
def test_subsequent_wait_calls(self):
"""
This test makes sure that a second call to wait()
clears the first timeout.
"""
self.io_loop.add_timeout(self.io_loop.time() + 0.01, self.stop)
self.wait(timeout=0.02)
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15)
class SetUpTearDownTest(unittest.TestCase):
def test_set_up_tear_down(self):
"""
This test makes sure that AsyncTestCase calls super methods for
setUp and tearDown.
InheritBoth is a subclass of both AsyncTestCase and
SetUpTearDown, with the ordering so that the super of
AsyncTestCase will be SetUpTearDown.
"""
events = []
result = unittest.TestResult()
class SetUpTearDown(unittest.TestCase):
def setUp(self):
events.append('setUp')
def tearDown(self):
events.append('tearDown')
class InheritBoth(AsyncTestCase, SetUpTearDown):
def test(self):
events.append('test')
InheritBoth('test').run(result)
expected = ['setUp', 'test', 'tearDown']
self.assertEqual(expected, events)
class GenTest(AsyncTestCase):
def setUp(self):
super(GenTest, self).setUp()
self.finished = False
def tearDown(self):
self.assertTrue(self.finished)
super(GenTest, self).tearDown()
@gen_test
def test_sync(self):
self.finished = True
@gen_test
def test_async(self):
yield gen.Task(self.io_loop.add_callback)
self.finished = True
def test_timeout(self):
# Set a short timeout and exceed it.
@gen_test(timeout=0.1)
def test(self):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
with self.assertRaises(ioloop.TimeoutError):
test(self)
self.finished = True
def test_no_timeout(self):
# A test that does not exceed its timeout should succeed.
@gen_test(timeout=1)
def test(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.1)
test(self)
self.finished = True
def test_timeout_environment_variable(self):
@gen_test(timeout=0.5)
def test_long_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 0.25)
# Uses provided timeout of 0.5 seconds, doesn't time out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
test_long_timeout(self)
self.finished = True
def test_no_timeout_environment_variable(self):
@gen_test(timeout=0.01)
def test_short_timeout(self):
time = self.io_loop.time
yield gen.Task(self.io_loop.add_timeout, time() + 1)
# Uses environment-variable timeout of 0.1, times out.
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
with self.assertRaises(ioloop.TimeoutError):
test_short_timeout(self)
self.finished = True
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,615 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Unittest for the twisted-style reactor.
"""
from __future__ import absolute_import, division, print_function, with_statement
import os
import shutil
import signal
import tempfile
import threading
try:
import fcntl
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
from tornado.platform.twisted import TornadoReactor, TwistedIOLoop
from zope.interface import implementer
have_twisted = True
except ImportError:
have_twisted = False
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not,
# so test for it separately.
try:
from twisted.web.client import Agent
from twisted.web.resource import Resource
from twisted.web.server import Site
have_twisted_web = True
except ImportError:
have_twisted_web = False
try:
import thread # py2
except ImportError:
import _thread as thread # py3
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.platform.select import SelectIOLoop
from tornado.testing import bind_unused_port
from tornado.test.util import unittest
from tornado.util import import_object
from tornado.web import RequestHandler, Application
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
def save_signal_handlers():
saved = {}
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGCHLD]:
saved[sig] = signal.getsignal(sig)
if "twisted" in repr(saved):
if not issubclass(IOLoop.configured_class(), TwistedIOLoop):
# when the global ioloop is twisted, we expect the signal
# handlers to be installed. Otherwise, it means we're not
# cleaning up after twisted properly.
raise Exception("twisted signal handlers already installed")
return saved
def restore_signal_handlers(saved):
for sig, handler in saved.items():
signal.signal(sig, handler)
class ReactorTestCase(unittest.TestCase):
def setUp(self):
self._saved_signals = save_signal_handlers()
self._io_loop = IOLoop()
self._reactor = TornadoReactor(self._io_loop)
def tearDown(self):
self._io_loop.close(all_fds=True)
restore_signal_handlers(self._saved_signals)
@skipIfNoTwisted
class ReactorWhenRunningTest(ReactorTestCase):
def test_whenRunning(self):
self._whenRunningCalled = False
self._anotherWhenRunningCalled = False
self._reactor.callWhenRunning(self.whenRunningCallback)
self._reactor.run()
self.assertTrue(self._whenRunningCalled)
self.assertTrue(self._anotherWhenRunningCalled)
def whenRunningCallback(self):
self._whenRunningCalled = True
self._reactor.callWhenRunning(self.anotherWhenRunningCallback)
self._reactor.stop()
def anotherWhenRunningCallback(self):
self._anotherWhenRunningCalled = True
@skipIfNoTwisted
class ReactorCallLaterTest(ReactorTestCase):
def test_callLater(self):
self._laterCalled = False
self._now = self._reactor.seconds()
self._timeout = 0.001
dc = self._reactor.callLater(self._timeout, self.callLaterCallback)
self.assertEqual(self._reactor.getDelayedCalls(), [dc])
self._reactor.run()
self.assertTrue(self._laterCalled)
self.assertTrue(self._called - self._now > self._timeout)
self.assertEqual(self._reactor.getDelayedCalls(), [])
def callLaterCallback(self):
self._laterCalled = True
self._called = self._reactor.seconds()
self._reactor.stop()
@skipIfNoTwisted
class ReactorTwoCallLaterTest(ReactorTestCase):
def test_callLater(self):
self._later1Called = False
self._later2Called = False
self._now = self._reactor.seconds()
self._timeout1 = 0.0005
dc1 = self._reactor.callLater(self._timeout1, self.callLaterCallback1)
self._timeout2 = 0.001
dc2 = self._reactor.callLater(self._timeout2, self.callLaterCallback2)
self.assertTrue(self._reactor.getDelayedCalls() == [dc1, dc2] or
self._reactor.getDelayedCalls() == [dc2, dc1])
self._reactor.run()
self.assertTrue(self._later1Called)
self.assertTrue(self._later2Called)
self.assertTrue(self._called1 - self._now > self._timeout1)
self.assertTrue(self._called2 - self._now > self._timeout2)
self.assertEqual(self._reactor.getDelayedCalls(), [])
def callLaterCallback1(self):
self._later1Called = True
self._called1 = self._reactor.seconds()
def callLaterCallback2(self):
self._later2Called = True
self._called2 = self._reactor.seconds()
self._reactor.stop()
@skipIfNoTwisted
class ReactorCallFromThreadTest(ReactorTestCase):
def setUp(self):
super(ReactorCallFromThreadTest, self).setUp()
self._mainThread = thread.get_ident()
def tearDown(self):
self._thread.join()
super(ReactorCallFromThreadTest, self).tearDown()
def _newThreadRun(self):
self.assertNotEqual(self._mainThread, thread.get_ident())
if hasattr(self._thread, 'ident'): # new in python 2.6
self.assertEqual(self._thread.ident, thread.get_ident())
self._reactor.callFromThread(self._fnCalledFromThread)
def _fnCalledFromThread(self):
self.assertEqual(self._mainThread, thread.get_ident())
self._reactor.stop()
def _whenRunningCallback(self):
self._thread = threading.Thread(target=self._newThreadRun)
self._thread.start()
def testCallFromThread(self):
self._reactor.callWhenRunning(self._whenRunningCallback)
self._reactor.run()
@skipIfNoTwisted
class ReactorCallInThread(ReactorTestCase):
def setUp(self):
super(ReactorCallInThread, self).setUp()
self._mainThread = thread.get_ident()
def _fnCalledInThread(self, *args, **kwargs):
self.assertNotEqual(thread.get_ident(), self._mainThread)
self._reactor.callFromThread(lambda: self._reactor.stop())
def _whenRunningCallback(self):
self._reactor.callInThread(self._fnCalledInThread)
def testCallInThread(self):
self._reactor.callWhenRunning(self._whenRunningCallback)
self._reactor.run()
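# Reader and Writer adapt plain file objects to Twisted's IReadDescriptor and
# IWriteDescriptor interfaces (implementer() is applied below only when twisted
# is importable) so the reactor tests that follow can poll both ends of a pipe.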
class Reader(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Reader"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def readConnectionLost(self, reason):
self.close()
def connectionLost(self, reason):
self.close()
def doRead(self):
self._callback(self._fd)
if have_twisted:
Reader = implementer(IReadDescriptor)(Reader)
class Writer(object):
def __init__(self, fd, callback):
self._fd = fd
self._callback = callback
def logPrefix(self):
return "Writer"
def close(self):
self._fd.close()
def fileno(self):
return self._fd.fileno()
def connectionLost(self, reason):
self.close()
def doWrite(self):
self._callback(self._fd)
if have_twisted:
Writer = implementer(IWriteDescriptor)(Writer)
@skipIfNoTwisted
class ReactorReaderWriterTest(ReactorTestCase):
def _set_nonblocking(self, fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
def setUp(self):
super(ReactorReaderWriterTest, self).setUp()
r, w = os.pipe()
self._set_nonblocking(r)
self._set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self._p1 = os.fdopen(r, "rb", 0)
self._p2 = os.fdopen(w, "wb", 0)
def tearDown(self):
super(ReactorReaderWriterTest, self).tearDown()
self._p1.close()
self._p2.close()
def _testReadWrite(self):
"""
In this test the writer writes an 'x' to its fd. The reader
reads it, checks the value, and ends the test.
"""
self.shouldWrite = True
def checkReadInput(fd):
self.assertEquals(fd.read(1), b'x')
self._reactor.stop()
def writeOnce(fd):
if self.shouldWrite:
self.shouldWrite = False
fd.write(b'x')
self._reader = Reader(self._p1, checkReadInput)
self._writer = Writer(self._p2, writeOnce)
self._reactor.addWriter(self._writer)
# Test that adding the reader twice adds it only once to
# IOLoop.
self._reactor.addReader(self._reader)
self._reactor.addReader(self._reader)
def testReadWrite(self):
self._reactor.callWhenRunning(self._testReadWrite)
self._reactor.run()
def _testNoWriter(self):
"""
In this test we have no writer. Make sure the reader doesn't
read anything.
"""
def checkReadInput(fd):
self.fail("Must not be called.")
def stopTest():
# Close the writer here since the IOLoop doesn't know
# about it.
self._writer.close()
self._reactor.stop()
self._reader = Reader(self._p1, checkReadInput)
# We create a writer, but it should never be invoked.
self._writer = Writer(self._p2, lambda fd: fd.write('x'))
# Test that adding and removing the writer leaves us with no writer.
self._reactor.addWriter(self._writer)
self._reactor.removeWriter(self._writer)
# Test that adding and removing the reader doesn't cause
# unintended effects.
self._reactor.addReader(self._reader)
# Wake up after a moment and stop the test
self._reactor.callLater(0.001, stopTest)
def testNoWriter(self):
self._reactor.callWhenRunning(self._testNoWriter)
self._reactor.run()
# Test various combinations of twisted and tornado http servers,
# http clients, and event loop interfaces.
@skipIfNoTwisted
@unittest.skipIf(not have_twisted_web, 'twisted web not present')
class CompatibilityTests(unittest.TestCase):
def setUp(self):
self.saved_signals = save_signal_handlers()
self.io_loop = IOLoop()
self.io_loop.make_current()
self.reactor = TornadoReactor(self.io_loop)
def tearDown(self):
self.reactor.disconnectAll()
self.io_loop.clear_current()
self.io_loop.close(all_fds=True)
restore_signal_handlers(self.saved_signals)
def start_twisted_server(self):
class HelloResource(Resource):
isLeaf = True
def render_GET(self, request):
return "Hello from twisted!"
site = Site(HelloResource())
port = self.reactor.listenTCP(0, site, interface='127.0.0.1')
self.twisted_port = port.getHost().port
def start_tornado_server(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello from tornado!")
app = Application([('/', HelloHandler)],
log_function=lambda x: None)
server = HTTPServer(app, io_loop=self.io_loop)
sock, self.tornado_port = bind_unused_port()
server.add_sockets([sock])
def run_ioloop(self):
self.stop_loop = self.io_loop.stop
self.io_loop.start()
self.reactor.fireSystemEvent('shutdown')
def run_reactor(self):
self.stop_loop = self.reactor.stop
self.stop = self.reactor.stop
self.reactor.run()
def tornado_fetch(self, url, runner):
responses = []
client = AsyncHTTPClient(self.io_loop)
def callback(response):
responses.append(response)
self.stop_loop()
client.fetch(url, callback=callback)
runner()
self.assertEqual(len(responses), 1)
responses[0].rethrow()
return responses[0]
def twisted_fetch(self, url, runner):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
d = client.request('GET', url)
class Accumulator(Protocol):
def __init__(self, finished):
self.finished = finished
def dataReceived(self, data):
chunks.append(data)
def connectionLost(self, reason):
self.finished.callback(None)
def callback(response):
finished = Deferred()
response.deliverBody(Accumulator(finished))
return finished
d.addCallback(callback)
def shutdown(ignored):
self.stop_loop()
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return ''.join(chunks)
def testTwistedServerTornadoClientIOLoop(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_ioloop)
self.assertEqual(response.body, 'Hello from twisted!')
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_reactor)
self.assertEqual(response.body, 'Hello from twisted!')
def testTornadoServerTwistedClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
if have_twisted:
# Import and run as much of twisted's test suite as possible.
# This is unfortunately rather dependent on implementation details,
# but there doesn't appear to be a clean all-in-one conformance test
# suite for reactors.
#
# This is a list of all test suites using the ReactorBuilder
# available in Twisted 11.0.0 and 11.1.0 (and a blacklist of
# specific test methods to be disabled).
twisted_tests = {
'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
'twisted.internet.test.test_core.SystemEventTestsBuilder': [
'test_iterate', # deliberately not supported
'test_runAfterCrash', # fails because TwistedIOLoop uses the global reactor
] if issubclass(IOLoop.configured_class(), TwistedIOLoop) else [
'test_iterate', # deliberately not supported
],
'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
"test_lostFileDescriptor", # incompatible with epoll and kqueue
],
'twisted.internet.test.test_process.ProcessTestsBuilder': [
],
# Process tests appear to work on OSX 10.7, but not 10.6
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'test_systemCallUninterruptedByChildExit',
# ],
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
'test_badContext', # ssl-related; see also SSLClientTestsMixin
],
'twisted.internet.test.test_tcp.TCPPortTestsBuilder': [
# These use link-local addresses and cause firewall prompts on mac
'test_buildProtocolIPv6AddressScopeID',
'test_portGetHostOnIPv6ScopeID',
'test_serverGetHostOnIPv6ScopeID',
'test_serverGetPeerOnIPv6ScopeID',
],
'twisted.internet.test.test_tcp.TCPConnectionTestsBuilder': [],
'twisted.internet.test.test_tcp.WriteSequenceTests': [],
'twisted.internet.test.test_tcp.AbortConnectionTestCase': [],
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
'twisted.internet.test.test_time.TimeTestsBuilder': [],
# Extra third-party dependencies (pyOpenSSL)
#'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
# Platform-specific. These tests would be skipped automatically
# if we were running twisted's own test runner.
'test_connectToLinuxAbstractNamespace',
'test_listenOnLinuxAbstractNamespace',
# These tests use twisted's sendmsg.c extension and sometimes
# fail with what looks like uninitialized memory errors
# (more common on pypy than cpython, but I've seen it on both)
'test_sendFileDescriptor',
'test_sendFileDescriptorTriggersPauseProducing',
'test_descriptorDeliveredBeforeBytes',
'test_avoidLeakingFileDescriptors',
],
'twisted.internet.test.test_unix.UNIXDatagramTestsBuilder': [
'test_listenOnLinuxAbstractNamespace',
],
'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [],
}
for test_name, blacklist in twisted_tests.items():
try:
test_class = import_object(test_name)
except (ImportError, AttributeError):
continue
for test_func in blacklist:
if hasattr(test_class, test_func):
# The test_func may be defined in a mixin, so clobber
# it instead of delattr()
setattr(test_class, test_func, lambda self: None)
def make_test_subclass(test_class):
class TornadoTest(test_class):
_reactors = ["tornado.platform.twisted._TestReactor"]
def setUp(self):
# Twisted's tests expect to be run from a temporary
# directory; they create files in their working directory
# and don't always clean up after themselves.
self.__curdir = os.getcwd()
self.__tempdir = tempfile.mkdtemp()
os.chdir(self.__tempdir)
super(TornadoTest, self).setUp()
def tearDown(self):
super(TornadoTest, self).tearDown()
os.chdir(self.__curdir)
shutil.rmtree(self.__tempdir)
def buildReactor(self):
self.__saved_signals = save_signal_handlers()
return test_class.buildReactor(self)
def unbuildReactor(self, reactor):
test_class.unbuildReactor(self, reactor)
# Clean up file descriptors (especially epoll/kqueue
# objects) eagerly instead of leaving them for the
# GC. Unfortunately we can't do this in reactor.stop
# since twisted expects to be able to unregister
# connections in a post-shutdown hook.
reactor._io_loop.close(all_fds=True)
restore_signal_handlers(self.__saved_signals)
TornadoTest.__name__ = test_class.__name__
return TornadoTest
test_subclass = make_test_subclass(test_class)
globals().update(test_subclass.makeTestCaseClasses())
# Since we're not using twisted's test runner, it's tricky to get
# logging set up well. Most of the time it's easiest to just
# leave it turned off, but while working on these tests you may want
# to uncomment one of the other lines instead.
log.defaultObserver.stop()
# import sys; log.startLogging(sys.stderr, setStdout=0)
# log.startLoggingWithObserver(log.PythonLoggingObserver().emit, setStdout=0)
# import logging; logging.getLogger('twisted').setLevel(logging.WARNING)
if have_twisted:
class LayeredTwistedIOLoop(TwistedIOLoop):
"""Layers a TwistedIOLoop on top of a TornadoReactor on a SelectIOLoop.
This is of course silly, but is useful for testing purposes to make
sure we're implementing both sides of the various interfaces
correctly. In some tests another TornadoReactor is layered on top
of the whole stack.
"""
def initialize(self):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop()
reactor = TornadoReactor(io_loop=self.real_io_loop)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
self.add_callback(self.make_current)
def close(self, all_fds=False):
super(LayeredTwistedIOLoop, self).close(all_fds=all_fds)
# HACK: This is the same thing that test_class.unbuildReactor does.
for reader in self.reactor._internalReaders:
self.reactor.removeReader(reader)
reader.connectionLost(None)
self.real_io_loop.close(all_fds=all_fds)
def stop(self):
# One of twisted's tests fails if I don't delay crash()
# until the reactor has started, but if I move this to
# TwistedIOLoop then the tests fail when I'm *not* running
# tornado-on-twisted-on-tornado. I'm clearly missing something
# about the startup/crash semantics, but since stop and crash
# are really only used in tests it doesn't really matter.
self.reactor.callWhenRunning(self.reactor.crash)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,19 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import sys
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info >= (2, 7):
import unittest
else:
import unittest2 as unittest
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
"non-unix platform")
# travis-ci.org runs our tests in an overworked virtual machine, which makes
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
'timing tests unreliable on travis')

View file

@ -0,0 +1,164 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement
import sys
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer
from tornado.test.util import unittest
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
class RaiseExcInfoTest(unittest.TestCase):
def test_two_arg_exception(self):
# This test would fail on python 3 if raise_exc_info were simply
# a three-argument raise statement, because TwoArgException
# doesn't have a "copy constructor"
class TwoArgException(Exception):
def __init__(self, a, b):
super(TwoArgException, self).__init__()
self.a, self.b = a, b
try:
raise TwoArgException(1, 2)
except TwoArgException:
exc_info = sys.exc_info()
try:
raise_exc_info(exc_info)
self.fail("didn't get expected exception")
except TwoArgException as e:
self.assertIs(e, exc_info[1])
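# TestConfigurable and its two subclasses form a miniature hierarchy for the
# Configurable tests below: the base class selects TestConfig1 by default,
# and each subclass simply records the keyword argument it was built with.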
class TestConfigurable(Configurable):
@classmethod
def configurable_base(cls):
return TestConfigurable
@classmethod
def configurable_default(cls):
return TestConfig1
class TestConfig1(TestConfigurable):
def initialize(self, a=None):
self.a = a
class TestConfig2(TestConfigurable):
def initialize(self, b=None):
self.b = b
class ConfigurableTest(unittest.TestCase):
def setUp(self):
self.saved = TestConfigurable._save_configuration()
def tearDown(self):
TestConfigurable._restore_configuration(self.saved)
def checkSubclasses(self):
# no matter how the class is configured, it should always be
# possible to instantiate the subclasses directly
self.assertIsInstance(TestConfig1(), TestConfig1)
self.assertIsInstance(TestConfig2(), TestConfig2)
obj = TestConfig1(a=1)
self.assertEqual(obj.a, 1)
obj = TestConfig2(b=2)
self.assertEqual(obj.b, 2)
def test_default(self):
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig1)
self.assertIs(obj.a, None)
obj = TestConfigurable(a=1)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 1)
self.checkSubclasses()
def test_config_class(self):
TestConfigurable.configure(TestConfig2)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig2)
self.assertIs(obj.b, None)
obj = TestConfigurable(b=2)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 2)
self.checkSubclasses()
def test_config_args(self):
TestConfigurable.configure(None, a=3)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
obj = TestConfigurable(a=4)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig1()
self.assertIs(obj.a, None)
def test_config_class_args(self):
TestConfigurable.configure(TestConfig2, b=5)
obj = TestConfigurable()
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
obj = TestConfigurable(b=6)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig2()
self.assertIs(obj.b, None)
class UnicodeLiteralTest(unittest.TestCase):
def test_unicode_escapes(self):
self.assertEqual(utf8(u('\u00e9')), b'\xc3\xa9')
class ExecInTest(unittest.TestCase):
# This test is python 2 only because there are no new future imports
# defined in python 3 yet.
@unittest.skipIf(sys.version_info >= print_function.getMandatoryRelease(),
'no testable future imports')
def test_no_inherit_future(self):
# This file has from __future__ import print_function...
f = StringIO()
print('hello', file=f)
# ...but the template doesn't
exec_in('print >> f, "world"', dict(f=f))
self.assertEqual(f.getvalue(), 'hello\nworld\n')
class ArgReplacerTest(unittest.TestCase):
def setUp(self):
def function(x, y, callback=None, z=None):
pass
self.replacer = ArgReplacer(function, 'callback')
def test_omitted(self):
self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
(None, (1, 2), dict(callback='new')))
def test_position(self):
self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
('old', [1, 2, 'new', 3], dict()))
def test_keyword(self):
self.assertEqual(self.replacer.replace('new', (1,),
dict(y=2, callback='old', z=3)),
('old', (1,), dict(y=2, callback='new', z=3)))

File diff suppressed because it is too large

View file

@ -0,0 +1,87 @@
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
class EchoHandler(WebSocketHandler):
def initialize(self, close_future):
self.close_future = close_future
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
def on_close(self):
self.close_future.set_result(None)
class NonWebSocketHandler(RequestHandler):
def get(self):
self.write('ok')
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
])
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
def test_websocket_callbacks(self):
websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, 'hello')
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(
'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
yield websocket_connect(
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
@gen_test
def test_websocket_network_fail(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(HTTPError) as cm:
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
io_loop=self.io_loop,
connect_timeout=0.01)
self.assertEqual(cm.exception.code, 599)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
ws.stream.close()
yield self.close_future

View file

@ -0,0 +1,87 @@
from __future__ import absolute_import, division, print_function, with_statement
from wsgiref.validate import validator
from tornado.escape import json_decode
from tornado.test.httpserver_test import TypeCheckHandler
from tornado.testing import AsyncHTTPTestCase
from tornado.util import u
from tornado.web import RequestHandler
from tornado.wsgi import WSGIApplication, WSGIContainer
class WSGIContainerTest(AsyncHTTPTestCase):
def wsgi_app(self, environ, start_response):
status = "200 OK"
response_headers = [("Content-Type", "text/plain")]
start_response(status, response_headers)
return [b"Hello world!"]
def get_app(self):
return WSGIContainer(validator(self.wsgi_app))
def test_simple(self):
response = self.fetch("/")
self.assertEqual(response.body, b"Hello world!")
class WSGIApplicationTest(AsyncHTTPTestCase):
def get_app(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello world!")
class PathQuotingHandler(RequestHandler):
def get(self, path):
self.write(path)
# It would be better to run the wsgiref server implementation in
# another thread instead of using our own WSGIContainer, but this
# fits better in our async testing framework and the wsgiref
# validator should keep us honest
return WSGIContainer(validator(WSGIApplication([
("/", HelloHandler),
("/path/(.*)", PathQuotingHandler),
("/typecheck", TypeCheckHandler),
])))
def test_simple(self):
response = self.fetch("/")
self.assertEqual(response.body, b"Hello world!")
def test_path_quoting(self):
response = self.fetch("/path/foo%20bar%C3%A9")
self.assertEqual(response.body, u("foo bar\u00e9").encode("utf-8"))
def test_types(self):
headers = {"Cookie": "foo=bar"}
response = self.fetch("/typecheck?foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
# This is kind of hacky, but run some of the HTTPServer tests through
# WSGIContainer and WSGIApplication to make sure everything survives
# repeated disassembly and reassembly.
from tornado.test import httpserver_test
from tornado.test import web_test
class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
def get_app(self):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
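# Generate a WSGI-wrapped variant of every wsgi-safe web test: each wrapped
# class routes the original handlers through WSGIApplication + WSGIContainer
# (still under wsgiref's validator) and is injected into this module's
# globals so the test loader discovers it.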
def wrap_web_tests():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
return result
globals().update(wrap_web_tests())

View file

@@ -0,0 +1,615 @@
#!/usr/bin/env python
"""Support classes for automated testing.
* `AsyncTestCase` and `AsyncHTTPTestCase`: Subclasses of unittest.TestCase
with additional support for testing asynchronous (`.IOLoop` based) code.
* `ExpectLog` and `LogTrapTestCase`: Make test logs less spammy.
* `main()`: A simple test runner (wrapper around unittest.main()) with support
for the tornado.autoreload module to rerun the tests when code changes.
"""
from __future__ import absolute_import, division, print_function, with_statement
try:
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop
from tornado import netutil
except ImportError:
# These modules are not importable on app engine. Parts of this module
# won't work, but e.g. LogTrapTestCase and main() will.
AsyncHTTPClient = None
gen = None
HTTPServer = None
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
import functools
import logging
import os
import re
import signal
import socket
import sys
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
# Tornado's own test suite requires the updated unittest module
# (either py27+ or unittest2) so tornado.test.util enforces
# this requirement, but for other users of tornado.testing we want
# to allow the older version if unittest2 is not available.
try:
import unittest2 as unittest
except ImportError:
import unittest
_next_port = 10000
def get_unused_port():
"""Returns a (hopefully) unused port number.
This function does not guarantee that the port it returns is available,
only that a series of get_unused_port calls in a single process return
distinct ports.
**Deprecated**. Use bind_unused_port instead, which is guaranteed
to find an unused port.
"""
global _next_port
port = _next_port
_next_port = _next_port + 1
return port
def bind_unused_port():
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
"""
[sock] = netutil.bind_sockets(None, 'localhost', family=socket.AF_INET)
port = sock.getsockname()[1]
return sock, port
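# Typical usage sketch (this is the pattern AsyncHTTPTestCase.setUp follows
# below): bind first, then hand the already-bound socket to the server, so
# the port cannot be taken by another process in between.
#
#   sock, port = bind_unused_port()
#   server = HTTPServer(app)
#   server.add_sockets([sock])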
def get_async_test_timeout():
"""Get the global timeout setting for async tests.
Returns a float, the timeout in seconds.
.. versionadded:: 3.1
"""
try:
return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
except (ValueError, TypeError):
return 5
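# Example: raise the global timeout for all async tests from the shell,
# without touching any test code:
#
#   ASYNC_TEST_TIMEOUT=15 python -m tornado.test.runtests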
class AsyncTestCase(unittest.TestCase):
"""`~unittest.TestCase` subclass for testing `.IOLoop`-based
asynchronous code.
The unittest framework is synchronous, so the test must be
complete by the time the test method returns. This means that
asynchronous code cannot be used in quite the same way as usual.
To write test functions that use the same ``yield``-based patterns
used with the `tornado.gen` module, decorate your test methods
with `tornado.testing.gen_test` instead of
`tornado.gen.coroutine`. This class also provides the `stop()`
and `wait()` methods for a more manual style of testing. The test
method itself must call ``self.wait()``, and asynchronous
callbacks should call ``self.stop()`` to signal completion.
By default, a new `.IOLoop` is constructed for each test and is available
as ``self.io_loop``. This `.IOLoop` should be used in the construction of
HTTP clients/servers, etc. If the code being tested requires a
global `.IOLoop`, subclasses should override `get_new_ioloop` to return it.
The `.IOLoop`'s ``start`` and ``stop`` methods should not be
called directly. Instead, use `self.stop <stop>` and `self.wait
<wait>`. Arguments passed to ``self.stop`` are returned from
``self.wait``. It is possible to have multiple ``wait``/``stop``
cycles in the same test.
Example::
# This test uses coroutine style.
class MyTestCase(AsyncTestCase):
@tornado.testing.gen_test
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
response = yield client.fetch("http://www.tornadoweb.org")
# Test contents of response
self.assertIn("FriendFeed", response.body)
# This test uses argument passing between self.stop and self.wait.
class MyTestCase2(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
client.fetch("http://www.tornadoweb.org/", self.stop)
response = self.wait()
# Test contents of response
self.assertIn("FriendFeed", response.body)
# This test uses an explicit callback-based style.
class MyTestCase3(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
client.fetch("http://www.tornadoweb.org/", self.handle_fetch)
self.wait()
def handle_fetch(self, response):
# Test contents of response (failures and exceptions here
# will cause self.wait() to throw an exception and end the
# test).
# Exceptions thrown here are magically propagated to
# self.wait() in test_http_fetch() via stack_context.
self.assertIn("FriendFeed", response.body)
self.stop()
"""
def __init__(self, *args, **kwargs):
super(AsyncTestCase, self).__init__(*args, **kwargs)
self.__stopped = False
self.__running = False
self.__failure = None
self.__stop_args = None
self.__timeout = None
def setUp(self):
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
self.io_loop.make_current()
def tearDown(self):
self.io_loop.clear_current()
if (not IOLoop.initialized() or
self.io_loop is not IOLoop.instance()):
# Try to clean up any file descriptors left open in the ioloop.
# This avoids leaks, especially when tests are run repeatedly
# in the same process with autoreload (because curl does not
# set FD_CLOEXEC on its file descriptors)
self.io_loop.close(all_fds=True)
super(AsyncTestCase, self).tearDown()
# In case an exception escaped or the StackContext caught an exception
# when there wasn't a wait() to re-raise it, do so here.
# This is our last chance to raise an exception in a way that the
# unittest machinery understands.
self.__rethrow()
def get_new_ioloop(self):
"""Creates a new `.IOLoop` for this test. May be overridden in
subclasses for tests that require a specific `.IOLoop` (usually
the singleton `.IOLoop.instance()`).
"""
return IOLoop()
def _handle_exception(self, typ, value, tb):
self.__failure = (typ, value, tb)
self.stop()
return True
def __rethrow(self):
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
def run(self, result=None):
with ExceptionStackContext(self._handle_exception):
super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
# unittest run to fail messily, but that's better than silently
# ignoring an error.
self.__rethrow()
def stop(self, _arg=None, **kwargs):
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
Keyword arguments or a single positional argument passed to `stop()` are
saved and will be returned by `wait()`.
"""
assert _arg is None or not kwargs
self.__stop_args = kwargs or _arg
if self.__running:
self.io_loop.stop()
self.__running = False
self.__stopped = True
def wait(self, condition=None, timeout=None):
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
default timeout is 5 seconds; it may be overridden with a
``timeout`` keyword argument or globally with the
``ASYNC_TEST_TIMEOUT`` environment variable.
If ``condition`` is not None, the `.IOLoop` will be restarted
after `stop()` until ``condition()`` returns true.
.. versionchanged:: 3.1
Added the ``ASYNC_TEST_TIMEOUT`` environment variable.
"""
if timeout is None:
timeout = get_async_test_timeout()
if not self.__stopped:
if timeout:
def timeout_func():
try:
raise self.failureException(
'Async operation timed out after %s seconds' %
timeout)
except Exception:
self.__failure = sys.exc_info()
self.stop()
self.__timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, timeout_func)
while True:
self.__running = True
self.io_loop.start()
if (self.__failure is not None or
condition is None or condition()):
break
if self.__timeout is not None:
self.io_loop.remove_timeout(self.__timeout)
self.__timeout = None
assert self.__stopped
self.__stopped = False
self.__rethrow()
result = self.__stop_args
self.__stop_args = None
return result
class AsyncHTTPTestCase(AsyncTestCase):
"""A test case that starts up an HTTP server.
Subclasses must override `get_app()`, which returns the
`tornado.web.Application` (or other `.HTTPServer` callback) to be tested.
Tests will typically use the provided ``self.http_client`` to fetch
URLs from this server.
Example::
class MyHTTPTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', MyHandler)...])
def test_homepage(self):
# The following two lines are equivalent to
# response = self.fetch('/')
# but are shown in full here to demonstrate explicit use
# of self.stop and self.wait.
self.http_client.fetch(self.get_url('/'), self.stop)
response = self.wait()
# test contents of response
"""
def setUp(self):
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
self.__port = port
self.http_client = self.get_http_client()
self._app = self.get_app()
self.http_server = self.get_http_server()
self.http_server.add_sockets([sock])
def get_http_client(self):
return AsyncHTTPClient(io_loop=self.io_loop)
def get_http_server(self):
return HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
def get_app(self):
"""Should be overridden by subclasses to return a
`tornado.web.Application` or other `.HTTPServer` callback.
"""
raise NotImplementedError()
def fetch(self, path, **kwargs):
"""Convenience method to synchronously fetch a url.
The given path will be appended to the local server's host and
port. Any additional kwargs will be passed directly to
`.AsyncHTTPClient.fetch` (and so could be used to pass
``method="POST"``, ``body="..."``, etc).
"""
self.http_client.fetch(self.get_url(path), self.stop, **kwargs)
return self.wait()
def get_httpserver_options(self):
"""May be overridden by subclasses to return additional
keyword arguments for the server.
"""
return {}
def get_http_port(self):
"""Returns the port used by the server.
A new port is chosen for each test.
"""
return self.__port
def get_protocol(self):
return 'http'
def get_url(self, path):
"""Returns an absolute url for the given path on the test server."""
return '%s://localhost:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
def tearDown(self):
self.http_server.stop()
if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close()
super(AsyncHTTPTestCase, self).tearDown()
class AsyncHTTPSTestCase(AsyncHTTPTestCase):
"""A test case that starts an HTTPS server.
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
defaults=dict(validate_cert=False))
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self):
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
"""
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'))
def get_protocol(self):
return 'https'
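# A minimal sketch of overriding the test certificate (the class name and
# paths here are hypothetical; by default the bundled self-signed pair
# above is used):
#
#   class MyHTTPSTest(AsyncHTTPSTestCase):
#       def get_app(self):
#           return Application([('/', MyHandler)])
#       def get_ssl_options(self):
#           return dict(certfile='/path/to/my.crt',
#                       keyfile='/path/to/my.key')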
def gen_test(func=None, timeout=None):
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
already running. ``@gen_test`` should be applied to test methods
on subclasses of `AsyncTestCase`.
Example::
class MyTest(AsyncHTTPTestCase):
@gen_test
def test_something(self):
response = yield gen.Task(self.fetch('/'))
By default, ``@gen_test`` times out after 5 seconds. The timeout may be
overridden globally with the ``ASYNC_TEST_TIMEOUT`` environment variable,
or for each test with the ``timeout`` keyword argument::
class MyTest(AsyncHTTPTestCase):
@gen_test(timeout=10)
def test_something_slow(self):
response = yield gen.Task(self.fetch('/'))
.. versionadded:: 3.1
The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment
variable.
"""
if timeout is None:
timeout = get_async_test_timeout()
def wrap(f):
f = gen.coroutine(f)
@functools.wraps(f)
def wrapper(self):
return self.io_loop.run_sync(
functools.partial(f, self), timeout=timeout)
return wrapper
if func is not None:
# Used like:
# @gen_test
# def f(self):
# pass
return wrap(func)
else:
# Used like @gen_test(timeout=10)
return wrap
# Without this attribute, nosetests will try to run gen_test as a test
# anywhere it is imported.
gen_test.__test__ = False
class LogTrapTestCase(unittest.TestCase):
"""A test case that captures and discards all logging output
if the test passes.
Some libraries can produce a lot of logging output even when
the test succeeds, so this class can be useful to minimize the noise.
Simply use it as a base class for your test case. It is safe to combine
with AsyncTestCase via multiple inheritance
(``class MyTestCase(AsyncHTTPTestCase, LogTrapTestCase):``)
This class assumes that only one log handler is configured and
that it is a `~logging.StreamHandler`. This is true for both
`logging.basicConfig` and the "pretty logging" configured by
`tornado.options`. It is not compatible with other log buffering
mechanisms, such as those provided by some test runners.
"""
def run(self, result=None):
logger = logging.getLogger()
if not logger.handlers:
logging.basicConfig()
handler = logger.handlers[0]
if (len(logger.handlers) > 1 or
not isinstance(handler, logging.StreamHandler)):
# Logging has been configured in a way we don't recognize,
# so just leave it alone.
super(LogTrapTestCase, self).run(result)
return
old_stream = handler.stream
try:
handler.stream = StringIO()
gen_log.info("RUNNING TEST: " + str(self))
old_error_count = len(result.failures) + len(result.errors)
super(LogTrapTestCase, self).run(result)
new_error_count = len(result.failures) + len(result.errors)
if new_error_count != old_error_count:
old_stream.write(handler.stream.getvalue())
finally:
handler.stream = old_stream
class ExpectLog(logging.Filter):
"""Context manager to capture and suppress expected log output.
Useful to make tests of error conditions less noisy, while still
leaving unexpected log entries visible. *Not thread safe.*
Usage::
with ExpectLog('tornado.application', "Uncaught exception"):
error_response = self.fetch("/some_page")
"""
def __init__(self, logger, regex, required=True):
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
an empty string to watch the root logger.
:param regex: Regular expression to match. Any log entries on
the specified logger that match this regex will be suppressed.
:param required: If true, an exception will be raised if the end of
the ``with`` statement is reached without matching any log entries.
"""
if isinstance(logger, basestring_type):
logger = logging.getLogger(logger)
self.logger = logger
self.regex = re.compile(regex)
self.required = required
self.matched = False
def filter(self, record):
message = record.getMessage()
if self.regex.match(message):
self.matched = True
return False
return True
def __enter__(self):
self.logger.addFilter(self)
def __exit__(self, typ, value, tb):
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
def main(**kwargs):
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
the standard library, but adds support for tornado-style option
parsing and log formatting.
The easiest way to run a test is via the command line::
python -m tornado.testing tornado.test.stack_context_test
See the standard library unittest module for ways in which tests can
be specified.
Projects with many tests may wish to define a test script like
``tornado/test/runtests.py``. This script should define a method
``all()`` which returns a test suite and then call
`tornado.testing.main()`. Note that even when a test script is
used, the ``all()`` test suite may be overridden by naming a
single test on the command line::
# Runs all tests
python -m tornado.test.runtests
# Runs one test
python -m tornado.test.runtests tornado.test.stack_context_test
Additional keyword arguments are passed through to ``unittest.main()``.
For example, use ``tornado.testing.main(verbosity=2)``
to show many test details as they are run.
See http://docs.python.org/library/unittest.html#unittest.main
for full argument list.
"""
from tornado.options import define, options, parse_command_line
define('exception_on_interrupt', type=bool, default=True,
help=("If true (default), ctrl-c raises a KeyboardInterrupt "
"exception. This prints a stack trace but cannot interrupt "
"certain operations. If false, the process is more reliably "
"killed, but does not print a stack trace."))
# support the same options as unittest's command-line interface
define('verbose', type=bool)
define('quiet', type=bool)
define('failfast', type=bool)
define('catch', type=bool)
define('buffer', type=bool)
argv = [sys.argv[0]] + parse_command_line(sys.argv)
if not options.exception_on_interrupt:
signal.signal(signal.SIGINT, signal.SIG_DFL)
if options.verbose is not None:
kwargs['verbosity'] = 2
if options.quiet is not None:
kwargs['verbosity'] = 0
if options.failfast is not None:
kwargs['failfast'] = True
if options.catch is not None:
kwargs['catchbreak'] = True
if options.buffer is not None:
kwargs['buffer'] = True
if __name__ == '__main__' and len(argv) == 1:
print("No tests specified", file=sys.stderr)
sys.exit(1)
try:
# In order to be able to run tests by their fully-qualified name
# on the command line without importing all tests here,
# module must be set to None. Python 3.2's unittest.main ignores
# defaultTest if no module is given (it tries to do its own
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
unittest.main(module=None, argv=argv, **kwargs)
else:
unittest.main(defaultTest="all", argv=argv, **kwargs)
except SystemExit as e:
if e.code == 0:
gen_log.info('PASS')
else:
gen_log.error('FAIL')
raise
if __name__ == '__main__':
main()

View file

@@ -0,0 +1,270 @@
"""Miscellaneous utility functions and classes.
This module is used internally by Tornado. It is not necessarily expected
that the functions and classes defined here will be useful to other
applications, but they are documented here in case they are.
The one public-facing part of this module is the `Configurable` class
and its `~Configurable.configure` method, which becomes a part of the
interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
and `.Resolver`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import inspect
import sys
import zlib
class ObjectDict(dict):
"""Makes a dictionary behave like an object, with attribute-style access.
"""
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name, value):
self[name] = value
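# A short illustration of ObjectDict's attribute-style access:
#
#   d = ObjectDict(a=1)
#   d.b = 2            # equivalent to d["b"] = 2
#   d.a + d["b"]       # -> 3
#   d.missing          # raises AttributeError (not KeyError)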
class GzipDecompressor(object):
"""Streaming gzip decompressor.
The interface is like that of `zlib.decompressobj` (without the
optional arguments), but it understands gzip headers and checksums.
"""
def __init__(self):
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
def decompress(self, value):
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
be called when there is no more input data to ensure that
all data was processed.
"""
return self.decompressobj.decompress(value)
def flush(self):
"""Return any remaining buffered data not yet returned by decompress.
Also checks for errors such as truncated input.
No other methods may be called on this object after `flush`.
"""
return self.decompressobj.flush()
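# A minimal round-trip sketch of the streaming interface, using only the
# stdlib: build a gzip-framed payload with zlib (16 + MAX_WBITS selects the
# gzip container), then feed it back to GzipDecompressor in two chunks.
def _gzip_decompressor_example():
    compressor = zlib.compressobj(9, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
    payload = compressor.compress(b"hello world") + compressor.flush()
    decompressor = GzipDecompressor()
    data = decompressor.decompress(payload[:8])    # header fragment; may yield b""
    data += decompressor.decompress(payload[8:])   # the rest of the stream
    data += decompressor.flush()                   # drains buffers, checks checksum
    assert data == b"hello world"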
def import_object(name):
"""Imports an object by name.
import_object('x') is equivalent to 'import x'.
import_object('x.y.z') is equivalent to 'from x.y import z'.
>>> import tornado.escape
>>> import_object('tornado.escape') is tornado.escape
True
>>> import_object('tornado.escape.utf8') is tornado.escape.utf8
True
>>> import_object('tornado') is tornado
True
>>> import_object('tornado.missing_module')
Traceback (most recent call last):
...
ImportError: No module named missing_module
"""
if name.count('.') == 0:
return __import__(name, None, None)
parts = name.split('.')
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
try:
return getattr(obj, parts[-1])
except AttributeError:
raise ImportError("No module named %s" % parts[-1])
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if type('') is not type(b''):
def u(s):
return s
bytes_type = bytes
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
bytes_type = str
unicode_type = unicode
basestring_type = basestring
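# For example, u('caf\\u00e9') yields the same unicode string on python 2
# and python 3, where a u'' literal would be a syntax error on python 3.2.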
if sys.version_info > (3,):
exec("""
def raise_exc_info(exc_info):
raise exc_info[1].with_traceback(exc_info[2])
def exec_in(code, glob, loc=None):
if isinstance(code, str):
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec(code, glob, loc)
""")
else:
exec("""
def raise_exc_info(exc_info):
raise exc_info[0], exc_info[1], exc_info[2]
def exec_in(code, glob, loc=None):
if isinstance(code, basestring):
# exec(string) inherits the caller's future imports; compile
# the string first to prevent that.
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec code in glob, loc
""")
class Configurable(object):
"""Base class for configurable interfaces.
A configurable interface is an (abstract) class whose constructor
acts as a factory function for one of its implementation subclasses.
The implementation subclass as well as optional keyword arguments to
its initializer can be set globally at runtime with `configure`.
By using the constructor as the factory method, the interface
looks like a normal class, `isinstance` works as usual, etc. This
pattern is most useful when the choice of implementation is likely
to be a global decision (e.g. when `~select.epoll` is available,
always use it instead of `~select.select`), or when a
previously-monolithic class has been split into specialized
subclasses.
Configurable subclasses must define the class methods
`configurable_base` and `configurable_default`, and use the instance
method `initialize` instead of ``__init__``.
"""
__impl_class = None
__impl_kwargs = None
def __new__(cls, **kwargs):
base = cls.configurable_base()
args = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
args.update(base.__impl_kwargs)
else:
impl = cls
args.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(**args)
return instance
@classmethod
def configurable_base(cls):
"""Returns the base class of a configurable hierarchy.
This will normally return the class in which it is defined
(which is *not* necessarily the same as the ``cls`` classmethod parameter).
"""
raise NotImplementedError()
@classmethod
def configurable_default(cls):
"""Returns the implementation class to be used if none is configured."""
raise NotImplementedError()
def initialize(self):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
"""
@classmethod
def configure(cls, impl, **kwargs):
"""Sets the class to use when the base class is instantiated.
Keyword arguments will be saved and added to the arguments passed
to the constructor. This can be used to set global defaults for
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes_type)):
impl = import_object(impl)
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)
base.__impl_class = impl
base.__impl_kwargs = kwargs
@classmethod
def configured_class(cls):
"""Returns the currently configured class."""
base = cls.configurable_base()
if cls.__impl_class is None:
base.__impl_class = cls.configurable_default()
return base.__impl_class
@classmethod
def _save_configuration(cls):
base = cls.configurable_base()
return (base.__impl_class, base.__impl_kwargs)
@classmethod
def _restore_configuration(cls, saved):
base = cls.configurable_base()
base.__impl_class = saved[0]
base.__impl_kwargs = saved[1]
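# A minimal sketch of a Configurable hierarchy (all names here are
# illustrative, not part of Tornado): ExampleBase is the abstract interface
# and ExampleImplA is its default implementation.
class ExampleBase(Configurable):
    @classmethod
    def configurable_base(cls):
        return ExampleBase

    @classmethod
    def configurable_default(cls):
        return ExampleImplA

    def initialize(self, greeting="hello"):
        self.greeting = greeting


class ExampleImplA(ExampleBase):
    pass


class ExampleImplB(ExampleBase):
    pass

# Usage (kept in a comment so importing this module has no side effects):
#
#   obj = ExampleBase()            # instance of ExampleImplA, greeting="hello"
#   ExampleBase.configure(ExampleImplB, greeting="hi")
#   obj = ExampleBase()            # instance of ExampleImplB, greeting="hi"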
class ArgReplacer(object):
"""Replaces one value in an ``args, kwargs`` pair.
Inspects the function signature to find an argument by name
whether it is passed by position or keyword. For use in decorators
and similar wrappers.
"""
def __init__(self, func, name):
self.name = name
try:
self.arg_pos = inspect.getargspec(func).args.index(self.name)
except ValueError:
# Not a positional parameter
self.arg_pos = None
def replace(self, new_value, args, kwargs):
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
Returns ``(old_value, args, kwargs)``. The returned ``args`` and
``kwargs`` objects may not be the same as the input objects, or
the input objects may be mutated.
If the named argument was not found, ``new_value`` will be added
to ``kwargs`` and None will be returned as ``old_value``.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
# The arg to replace is passed positionally
old_value = args[self.arg_pos]
args = list(args) # *args is normally a tuple
args[self.arg_pos] = new_value
else:
# The arg to replace is either omitted or passed by keyword.
old_value = kwargs.get(self.name)
kwargs[self.name] = new_value
return old_value, args, kwargs
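# An illustrative sketch of ArgReplacer (the function `connect` is made up):
# the 'callback' argument is replaced whether the caller passed it
# positionally, by keyword, or not at all.
def _arg_replacer_example():
    def connect(host, port, callback=None):
        return host, port, callback

    replacer = ArgReplacer(connect, 'callback')
    # 'callback' was omitted, so it is added to kwargs and old_value is None.
    old_value, args, kwargs = replacer.replace('wrapped_cb', ('localhost', 80), {})
    assert old_value is None
    assert connect(*args, **kwargs) == ('localhost', 80, 'wrapped_cb')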
def doctests():
import doctest
return doctest.DocTestSuite()

File diff suppressed because it is too large Load diff

View file

@@ -0,0 +1,862 @@
"""Implementation of the WebSocket protocol.
`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional
communication between the browser and server.
.. warning::
The WebSocket protocol was recently finalized as `RFC 6455
<http://tools.ietf.org/html/rfc6455>`_ and is not yet supported in
all browsers. Refer to http://caniuse.com/websockets for details
on compatibility. In addition, during development the protocol
went through several incompatible versions, and some browsers only
support older versions. By default this module only supports the
latest version of the protocol, but optional support for an older
version (known as "draft 76" or "hixie-76") can be enabled by
overriding `WebSocketHandler.allow_draft76` (see that method's
documentation for caveats).
"""
from __future__ import absolute_import, division, print_function, with_statement
# Author: Jacob Kristhammar, 2010
import array
import base64
import collections
import functools
import hashlib
import os
import struct
import time
import tornado.escape
import tornado.web
from tornado.concurrent import Future
from tornado.escape import utf8, native_str
from tornado import httpclient
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver
from tornado import simple_httpclient
from tornado.util import bytes_type, unicode_type
try:
xrange # py2
except NameError:
xrange = range # py3
class WebSocketError(Exception):
pass
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Override `on_message` to handle incoming messages, and use
`write_message` to send messages to the client. You can also
override `open` and `on_close` to handle opened and closed
connections.
See http://dev.w3.org/html5/websockets/ for details on the
JavaScript interface. The protocol is specified at
http://tools.ietf.org/html/rfc6455.
Here is an example WebSocket handler that echoes all received messages
back to the client::
class EchoWebSocket(websocket.WebSocketHandler):
def open(self):
print "WebSocket opened"
def on_message(self, message):
self.write_message(u"You said: " + message)
def on_close(self):
print "WebSocket closed"
WebSockets are not standard HTTP connections. The "handshake" is
HTTP, but after the handshake, the protocol is
message-based. Consequently, most of the Tornado HTTP facilities
are not available in handlers of this type. The only communication
methods available to you are `write_message()`, `ping()`, and
`close()`. Likewise, your request handler class should implement the
`open()` method rather than ``get()`` or ``post()``.
If you map the handler above to ``/websocket`` in your application, you can
invoke it in JavaScript with::
var ws = new WebSocket("ws://localhost:8888/websocket");
ws.onopen = function() {
ws.send("Hello, world");
};
ws.onmessage = function (evt) {
alert(evt.data);
};
This script pops up an alert box that says "You said: Hello, world".
"""
def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
**kwargs)
self.stream = request.connection.stream
self.ws_connection = None
def _execute(self, transforms, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
# Websocket only supports GET method
if self.request.method != 'GET':
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 405 Method Not Allowed\r\n\r\n"
))
self.stream.close()
return
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"Can \"Upgrade\" only to \"WebSocket\"."
))
self.stream.close()
return
# Connection header should be upgrade. Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"\"Connection\" must be \"Upgrade\"."
))
self.stream.close()
return
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
# simply "Origin".
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
elif (self.allow_draft76() and
"Sec-WebSocket-Version" not in self.request.headers):
self.ws_connection = WebSocketProtocol76(self)
self.ws_connection.accept_connection()
else:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
The message may be either a string or a dict (which will be
encoded as json). If the ``binary`` argument is false, the
message will be sent as utf8; in binary mode any byte string
is allowed.
"""
if isinstance(message, dict):
message = tornado.escape.json_encode(message)
self.ws_connection.write_message(message, binary=binary)
def select_subprotocol(self, subprotocols):
"""Invoked when a new WebSocket requests specific subprotocols.
``subprotocols`` is a list of strings identifying the
subprotocols proposed by the client. This method may be
overridden to return one of those strings to select it, or
``None`` to not select a subprotocol. Failure to select a
subprotocol does not automatically abort the connection,
although clients may close the connection if none of their
proposed subprotocols was selected.
"""
return None
def open(self):
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
regular expression, just like the arguments to
`tornado.web.RequestHandler.get`.
"""
pass
def on_message(self, message):
"""Handle incoming messages on the WebSocket
This method must be overridden.
"""
raise NotImplementedError
def ping(self, data):
"""Send ping frame to the remote end."""
self.ws_connection.write_ping(data)
def on_pong(self, data):
"""Invoked when the response to a ping frame is received."""
pass
def on_close(self):
"""Invoked when the WebSocket is closed."""
pass
def close(self):
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
"""
self.ws_connection.close()
self.ws_connection = None
def allow_draft76(self):
"""Override to enable support for the older "draft76" protocol.
The draft76 version of the websocket protocol is disabled by
default due to security concerns, but it can be enabled by
overriding this method to return True.
Connections using the draft76 protocol do not support the
``binary=True`` flag to `write_message`.
Support for the draft76 protocol is deprecated and will be
removed in a future version of Tornado.
"""
return False
def set_nodelay(self, value):
"""Set the no-delay flag for this stream.
By default, small messages may be delayed and/or combined to minimize
the number of packets sent. This can sometimes cause 200-500ms delays
due to the interaction between Nagle's algorithm and TCP delayed
ACKs. To reduce this delay (at the expense of possibly increasing
bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
connection is established.
See `.BaseIOStream.set_nodelay` for additional details.
.. versionadded:: 3.1
"""
self.stream.set_nodelay(value)
def get_websocket_scheme(self):
"""Return the url scheme used for this request, either "ws" or "wss".
This is normally decided by HTTPServer, but applications
may wish to override this if they are using an SSL proxy
that does not provide the X-Scheme header as understood
by HTTPServer.
Note that this is only used by the draft76 protocol.
"""
return "wss" if self.request.protocol == "https" else "ws"
def async_callback(self, callback, *args, **kwargs):
"""Obsolete - catches exceptions from the wrapped function.
This function is normally unnecessary thanks to
`tornado.stack_context`.
"""
return self.ws_connection.async_callback(callback, *args, **kwargs)
def _not_supported(self, *args, **kwargs):
raise Exception("Method not supported for Web Sockets")
def on_connection_close(self):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
self.on_close()
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
class WebSocketProtocol(object):
"""Base class for WebSocket protocol versions.
"""
def __init__(self, handler):
self.handler = handler
self.request = handler.request
self.stream = handler.stream
self.client_terminated = False
self.server_terminated = False
def async_callback(self, callback, *args, **kwargs):
"""Wrap callbacks with this if they are used on asynchronous requests.
Catches exceptions properly and closes this WebSocket if an exception
is uncaught.
"""
if args or kwargs:
callback = functools.partial(callback, *args, **kwargs)
def wrapper(*args, **kwargs):
try:
return callback(*args, **kwargs)
except Exception:
app_log.error("Uncaught exception in %s",
self.request.path, exc_info=True)
self._abort()
return wrapper
def on_connection_close(self):
self._abort()
def _abort(self):
"""Instantly aborts the WebSocket connection by closing the socket"""
self.client_terminated = True
self.server_terminated = True
self.stream.close() # forcibly tear down the connection
self.close() # let the subclass cleanup
class WebSocketProtocol76(WebSocketProtocol):
"""Implementation of the WebSockets protocol, version hixie-76.
This class provides basic functionality to process WebSockets requests as
specified in
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
"""
def __init__(self, handler):
WebSocketProtocol.__init__(self, handler)
self.challenge = None
self._waiting = None
def accept_connection(self):
try:
self._handle_websocket_headers()
except ValueError:
gen_log.debug("Malformed WebSocket request received")
self._abort()
return
scheme = self.handler.get_websocket_scheme()
# draft76 only allows a single subprotocol
subprotocol_header = ''
subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
if subprotocol:
selected = self.handler.select_subprotocol([subprotocol])
if selected:
assert selected == subprotocol
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
# Write the initial headers before attempting to read the challenge.
# This is necessary when using proxies (such as HAProxy), which
# need to see the Upgrade headers before passing through the
# non-HTTP traffic that follows.
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
"Server: TornadoServer/%(version)s\r\n"
"Sec-WebSocket-Origin: %(origin)s\r\n"
"Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
"%(subprotocol)s"
"\r\n" % (dict(
version=tornado.version,
origin=self.request.headers["Origin"],
scheme=scheme,
host=self.request.host,
uri=self.request.uri,
subprotocol=subprotocol_header))))
self.stream.read_bytes(8, self._handle_challenge)
def challenge_response(self, challenge):
"""Generates the challenge response that's needed in the handshake
The challenge parameter should be the raw bytes as sent from the
client.
"""
key_1 = self.request.headers.get("Sec-Websocket-Key1")
key_2 = self.request.headers.get("Sec-Websocket-Key2")
try:
part_1 = self._calculate_part(key_1)
part_2 = self._calculate_part(key_2)
except ValueError:
raise ValueError("Invalid Keys/Challenge")
return self._generate_challenge_response(part_1, part_2, challenge)
def _handle_challenge(self, challenge):
try:
challenge_response = self.challenge_response(challenge)
except ValueError:
gen_log.debug("Malformed key data in WebSocket request")
self._abort()
return
self._write_response(challenge_response)
def _write_response(self, challenge):
self.stream.write(challenge)
self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
self._receive_message()
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
fields = ("Origin", "Host", "Sec-Websocket-Key1",
"Sec-Websocket-Key2")
if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
def _calculate_part(self, key):
"""Processes the key headers and calculates their key value.
Raises ValueError when fed an invalid key."""
# pyflakes complains about variable reuse if both of these lines use 'c'
number = int(''.join(c for c in key if c.isdigit()))
spaces = len([c2 for c2 in key if c2.isspace()])
try:
key_number = number // spaces
except (ValueError, ZeroDivisionError):
raise ValueError
return struct.pack(">I", key_number)
def _generate_challenge_response(self, part_1, part_2, part_3):
m = hashlib.md5()
m.update(part_1)
m.update(part_2)
m.update(part_3)
return m.digest()
def _receive_message(self):
self.stream.read_bytes(1, self._on_frame_type)
def _on_frame_type(self, byte):
frame_type = ord(byte)
if frame_type == 0x00:
self.stream.read_until(b"\xff", self._on_end_delimiter)
elif frame_type == 0xff:
self.stream.read_bytes(1, self._on_length_indicator)
else:
self._abort()
def _on_end_delimiter(self, frame):
if not self.client_terminated:
self.async_callback(self.handler.on_message)(
frame[:-1].decode("utf-8", "replace"))
if not self.client_terminated:
self._receive_message()
def _on_length_indicator(self, byte):
if ord(byte) != 0x00:
self._abort()
return
self.client_terminated = True
self.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
if binary:
raise ValueError(
"Binary messages not supported by this version of websockets")
if isinstance(message, unicode_type):
message = message.encode("utf-8")
assert isinstance(message, bytes_type)
self.stream.write(b"\x00" + message + b"\xff")
def write_ping(self, data):
"""Send ping frame."""
raise ValueError("Ping messages not supported by this version of websockets")
def close(self):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
self.stream.write("\xff\x00")
self.server_terminated = True
if self.client_terminated:
if self._waiting is not None:
self.stream.io_loop.remove_timeout(self._waiting)
self._waiting = None
self.stream.close()
elif self._waiting is None:
self._waiting = self.stream.io_loop.add_timeout(
time.time() + 5, self._abort)
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
def __init__(self, handler, mask_outgoing=False):
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
self._frame_opcode = None
self._masked_frame = None
self._frame_mask = None
self._frame_length = None
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
def accept_connection(self):
try:
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
return
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
@staticmethod
def compute_accept_value(key):
"""Computes the value for the Sec-WebSocket-Accept header,
given the value for Sec-WebSocket-Key.
"""
sha1 = hashlib.sha1()
sha1.update(utf8(key))
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
return native_str(base64.b64encode(sha1.digest()))
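# Known-answer check from RFC 6455 section 1.3 (easy to verify by hand):
# the client key "dGhlIHNhbXBsZSBub25jZQ==" must produce the accept
# value "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".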
def _challenge_response(self):
return WebSocketProtocol13.compute_accept_value(
self.request.headers.get("Sec-Websocket-Key"))
def _accept_connection(self):
subprotocol_header = ''
subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
subprotocols = [s.strip() for s in subprotocols.split(',')]
if subprotocols:
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"%s"
"\r\n" % (self._challenge_response(), subprotocol_header)))
self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
self._receive_frame()
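# For reference, the RFC 6455 base frame layout that _write_frame and
# _on_frame_start below implement:
#
#   byte 0: FIN flag (0x80) | 3 reserved bits (0x70) | 4-bit opcode
#   byte 1: MASK flag (0x80) | 7-bit payload length
#           126 -> the next 2 bytes are a 16-bit big-endian length
#           127 -> the next 8 bytes are a 64-bit big-endian length
#   then, if MASK is set: a 4-byte masking key, then the payload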
def _write_frame(self, fin, opcode, data):
if fin:
finbit = 0x80
else:
finbit = 0
frame = struct.pack("B", finbit | opcode)
l = len(data)
if self.mask_outgoing:
mask_bit = 0x80
else:
mask_bit = 0
if l < 126:
frame += struct.pack("B", l | mask_bit)
elif l <= 0xFFFF:
frame += struct.pack("!BH", 126 | mask_bit, l)
else:
frame += struct.pack("!BQ", 127 | mask_bit, l)
if self.mask_outgoing:
mask = os.urandom(4)
data = mask + self._apply_mask(mask, data)
frame += data
self.stream.write(frame)
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
if binary:
opcode = 0x2
else:
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
try:
self._write_frame(True, opcode, message)
except StreamClosedError:
self._abort()
def write_ping(self, data):
"""Send ping frame."""
assert isinstance(data, bytes_type)
self._write_frame(True, 0x9, data)
def _receive_frame(self):
try:
self.stream.read_bytes(2, self._on_frame_start)
except StreamClosedError:
self._abort()
def _on_frame_start(self, data):
header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80
reserved_bits = header & 0x70
self._frame_opcode = header & 0xf
self._frame_opcode_is_control = self._frame_opcode & 0x8
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
return
self._masked_frame = bool(payloadlen & 0x80)
payloadlen = payloadlen & 0x7f
if self._frame_opcode_is_control and payloadlen >= 126:
# control frames must have payload < 126
self._abort()
return
try:
if payloadlen < 126:
self._frame_length = payloadlen
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
self.stream.read_bytes(8, self._on_frame_length_64)
except StreamClosedError:
self._abort()
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0]
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_masking_key(self, data):
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
except StreamClosedError:
self._abort()
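# Masking per RFC 6455 section 5.3: every payload byte is XORed with the
# masking-key byte at position (i % 4).  Since XOR is its own inverse,
# _apply_mask below serves both to mask (client) and unmask (server).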
def _apply_mask(self, mask, data):
mask = array.array("B", mask)
unmasked = array.array("B", data)
for i in xrange(len(data)):
unmasked[i] = unmasked[i] ^ mask[i % 4]
if hasattr(unmasked, 'tobytes'):
# tostring was deprecated in py32. It hasn't been removed,
# but since we turn on deprecation warnings in our tests
# we need to use the right one.
return unmasked.tobytes()
else:
return unmasked.tostring()
def _on_masked_frame_data(self, data):
self._on_frame_data(self._apply_mask(self._frame_mask, data))
def _on_frame_data(self, data):
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
# self._fragmented_*
if not self._final_frame:
# control frames must not be fragmented
self._abort()
return
opcode = self._frame_opcode
elif self._frame_opcode == 0: # continuation frame
if self._fragmented_message_buffer is None:
# nothing to continue
self._abort()
return
self._fragmented_message_buffer += data
if self._final_frame:
opcode = self._fragmented_message_opcode
data = self._fragmented_message_buffer
self._fragmented_message_buffer = None
else: # start of new data message
if self._fragmented_message_buffer is not None:
# can't start new message until the old one is finished
self._abort()
return
if self._final_frame:
opcode = self._frame_opcode
else:
self._fragmented_message_opcode = self._frame_opcode
self._fragmented_message_buffer = data
if self._final_frame:
self._handle_message(opcode, data)
if not self.client_terminated:
self._receive_frame()
def _handle_message(self, opcode, data):
if self.client_terminated:
return
if opcode == 0x1:
# UTF-8 data
try:
decoded = data.decode("utf-8")
except UnicodeDecodeError:
self._abort()
return
self.async_callback(self.handler.on_message)(decoded)
elif opcode == 0x2:
# Binary data
self.async_callback(self.handler.on_message)(data)
elif opcode == 0x8:
# Close
self.client_terminated = True
self.close()
elif opcode == 0x9:
# Ping
self._write_frame(True, 0xA, data)
elif opcode == 0xA:
# Pong
self.async_callback(self.handler.on_pong)(data)
else:
self._abort()
def close(self):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
self._write_frame(True, 0x8, b"")
self.server_terminated = True
if self.client_terminated:
if self._waiting is not None:
self.stream.io_loop.remove_timeout(self._waiting)
self._waiting = None
self.stream.close()
elif self._waiting is None:
# Give the client a few seconds to complete a clean shutdown,
# otherwise just close the connection.
self._waiting = self.stream.io_loop.add_timeout(
self.stream.io_loop.time() + 5, self._abort)
class WebSocketClientConnection(simple_httpclient._HTTPConnection):
"""WebSocket client connection."""
def __init__(self, io_loop, request):
self.connect_future = Future()
self.read_future = None
self.read_queue = collections.deque()
self.key = base64.b64encode(os.urandom(16))
scheme, sep, rest = request.url.partition(':')
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
request.url = scheme + sep + rest
request.headers.update({
'Upgrade': 'websocket',
'Connection': 'Upgrade',
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
self.resolver = Resolver(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.resolver)
def _on_close(self):
self.on_message(None)
self.resolver.close()
def _on_http_response(self, response):
if not self.connect_future.done():
if response.error:
self.connect_future.set_exception(response.error)
else:
self.connect_future.set_exception(WebSocketError(
"Non-websocket response"))
def _handle_1xx(self, code):
assert code == 101
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
assert self.headers['Sec-Websocket-Accept'] == accept
self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
self.protocol._receive_frame()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
self.connect_future.set_result(self)
def write_message(self, message, binary=False):
"""Sends a message to the WebSocket server."""
self.protocol.write_message(message, binary)
def read_message(self, callback=None):
"""Reads a message from the WebSocket server.
Returns a future whose result is the message, or None
if the connection is closed. If a callback argument
is given it will be called with the future when it is
ready.
"""
assert self.read_future is None
future = Future()
if self.read_queue:
future.set_result(self.read_queue.popleft())
else:
self.read_future = future
if callback is not None:
self.io_loop.add_future(future, callback)
return future
def on_message(self, message):
if self.read_future is not None:
self.read_future.set_result(message)
self.read_future = None
else:
self.read_queue.append(message)
def on_pong(self, data):
pass
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
`WebSocketClientConnection`.
"""
if io_loop is None:
io_loop = IOLoop.current()
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future
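# A minimal coroutine-style client sketch (assumes a server at the given URL
# and `from tornado import gen` in the caller's module):
#
#   @gen.coroutine
#   def talk():
#       conn = yield websocket_connect("ws://localhost:8888/websocket")
#       conn.write_message("hello")
#       reply = yield conn.read_message()   # None once the connection closes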

View file

@@ -0,0 +1,320 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""WSGI support for the Tornado web framework.
WSGI is the Python standard for web servers, and allows for interoperability
between Tornado and other Python web frameworks and servers. This module
provides WSGI support in two ways:
* `WSGIApplication` is a version of `tornado.web.Application` that can run
inside a WSGI server. This is useful for running a Tornado app on another
HTTP server, such as Google App Engine. See the `WSGIApplication` class
documentation for limitations that apply.
* `WSGIContainer` lets you run other WSGI applications and frameworks on the
Tornado HTTP server. For example, with this class you can mix Django
and Tornado handlers in a single server.
"""
from __future__ import absolute_import, division, print_function, with_statement
import sys
import time
import tornado
from tornado import escape
from tornado import httputil
from tornado.log import access_log
from tornado import web
from tornado.escape import native_str, parse_qs_bytes
from tornado.util import bytes_type, unicode_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse
# PEP 3333 specifies that WSGI on python 3 generally deals with byte strings
# that are smuggled inside objects of type unicode (via the latin1 encoding).
# These functions are like those in the tornado.escape module, but defined
# here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode_type:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s.decode('latin1')
def from_wsgi_str(s):
assert isinstance(s, str)
return s.encode('latin1')
else:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s
def from_wsgi_str(s):
assert isinstance(s, str)
return s
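# Concretely: the raw request path b"/caf\xc3\xa9" (UTF-8 for "/café")
# arrives in environ["PATH_INFO"] on python 3 as the str "/caf\xc3\xa9",
# each byte mapped through latin1; from_wsgi_str() recovers the bytes.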
class WSGIApplication(web.Application):
"""A WSGI equivalent of `tornado.web.Application`.
`WSGIApplication` is very similar to `tornado.web.Application`,
except no asynchronous methods are supported (since WSGI does not
support non-blocking requests properly). If you call
``self.flush()`` or other asynchronous methods in your request
handlers running in a `WSGIApplication`, we throw an exception.
Example usage::
import tornado.web
import tornado.wsgi
import wsgiref.simple_server
class MainHandler(tornado.web.RequestHandler):
def get(self):
self.write("Hello, world")
if __name__ == "__main__":
application = tornado.wsgi.WSGIApplication([
(r"/", MainHandler),
])
server = wsgiref.simple_server.make_server('', 8888, application)
server.serve_forever()
See the `appengine demo
<https://github.com/facebook/tornado/tree/master/demos/appengine>`_
for an example of using this module to run a Tornado app on Google
App Engine.
WSGI applications use the same `.RequestHandler` class, but not
``@asynchronous`` methods or ``flush()``. This means that it is
not possible to use `.AsyncHTTPClient`, or the `tornado.auth` or
`tornado.websocket` modules.
"""
def __init__(self, handlers=None, default_host="", **settings):
web.Application.__init__(self, handlers, default_host, transforms=[],
wsgi=True, **settings)
def __call__(self, environ, start_response):
handler = web.Application.__call__(self, HTTPRequest(environ))
assert handler._finished
reason = handler._reason
status = str(handler._status_code) + " " + reason
headers = list(handler._headers.get_all())
if hasattr(handler, "_new_cookie"):
for cookie in handler._new_cookie.values():
headers.append(("Set-Cookie", cookie.OutputString(None)))
start_response(status,
[(native_str(k), native_str(v)) for (k, v) in headers])
return handler._write_buffer


class HTTPRequest(object):
    """Mimics `tornado.httpserver.HTTPRequest` for WSGI applications."""
    def __init__(self, environ):
        """Parses the given WSGI environment to construct the request."""
        self.method = environ["REQUEST_METHOD"]
        self.path = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", "")))
        self.path += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
        self.uri = self.path
        self.arguments = {}
        self.query = environ.get("QUERY_STRING", "")
        if self.query:
            self.uri += "?" + self.query
            self.arguments = parse_qs_bytes(native_str(self.query),
                                            keep_blank_values=True)
        self.version = "HTTP/1.1"
        self.headers = httputil.HTTPHeaders()
        if environ.get("CONTENT_TYPE"):
            self.headers["Content-Type"] = environ["CONTENT_TYPE"]
        if environ.get("CONTENT_LENGTH"):
            self.headers["Content-Length"] = environ["CONTENT_LENGTH"]
        for key in environ:
            if key.startswith("HTTP_"):
                # Convert CGI-style keys back to header form, e.g.
                # "HTTP_USER_AGENT" -> "USER-AGENT" (HTTPHeaders
                # normalizes the case on assignment).
                self.headers[key[5:].replace("_", "-")] = environ[key]
        if self.headers.get("Content-Length"):
            self.body = environ["wsgi.input"].read(
                int(self.headers["Content-Length"]))
        else:
            self.body = b""  # bytes, for consistency with the read() branch
        self.protocol = environ["wsgi.url_scheme"]
        self.remote_ip = environ.get("REMOTE_ADDR", "")
        if environ.get("HTTP_HOST"):
            self.host = environ["HTTP_HOST"]
        else:
            self.host = environ["SERVER_NAME"]

        # Parse request body
        self.files = {}
        httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
                                      self.body, self.arguments, self.files)

        self._start_time = time.time()
        self._finish_time = None

    def supports_http_1_1(self):
        """Returns True if this request supports HTTP/1.1 semantics."""
        return self.version == "HTTP/1.1"

    @property
    def cookies(self):
        """A dictionary of Cookie.Morsel objects."""
        if not hasattr(self, "_cookies"):
            self._cookies = Cookie.SimpleCookie()
            if "Cookie" in self.headers:
                try:
                    self._cookies.load(
                        native_str(self.headers["Cookie"]))
                except Exception:
                    self._cookies = None
        return self._cookies

    def full_url(self):
        """Reconstructs the full URL for this request."""
        return self.protocol + "://" + self.host + self.uri

    def request_time(self):
        """Returns the amount of time it took for this request to execute."""
        if self._finish_time is None:
            return time.time() - self._start_time
        else:
            return self._finish_time - self._start_time


class WSGIContainer(object):
    r"""Makes a WSGI-compatible function runnable on Tornado's HTTP server.

    Wrap a WSGI function in a `WSGIContainer` and pass it to `.HTTPServer` to
    run it. For example::

        def simple_app(environ, start_response):
            status = "200 OK"
            response_headers = [("Content-type", "text/plain")]
            start_response(status, response_headers)
            return ["Hello world!\n"]

        container = tornado.wsgi.WSGIContainer(simple_app)
        http_server = tornado.httpserver.HTTPServer(container)
        http_server.listen(8888)
        tornado.ioloop.IOLoop.instance().start()

    This class is intended to let other frameworks (Django, web.py, etc)
    run on the Tornado HTTP server and I/O loop.

    The `tornado.web.FallbackHandler` class is often useful for mixing
    Tornado and WSGI apps in the same server. See
    https://github.com/bdarnell/django-tornado-demo for a complete example.
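
    A minimal sketch of such a mixed setup (``django_app`` and
    ``MainHandler`` are placeholder names for an existing WSGI callable
    and a Tornado handler, not part of this module)::

        wsgi_app = tornado.wsgi.WSGIContainer(django_app)
        application = tornado.web.Application([
            (r"/tornado", MainHandler),
            (r".*", tornado.web.FallbackHandler, dict(fallback=wsgi_app)),
        ])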
"""
    def __init__(self, wsgi_application):
        self.wsgi_application = wsgi_application

    def __call__(self, request):
        data = {}
        response = []

        def start_response(status, response_headers, exc_info=None):
            data["status"] = status
            data["headers"] = response_headers
            # PEP 3333 requires start_response to return a write callable;
            # appending to `response` keeps written chunks in order with
            # the chunks the application later yields.
            return response.append
        app_response = self.wsgi_application(
            WSGIContainer.environ(request), start_response)
        response.extend(app_response)
        body = b"".join(response)
        if hasattr(app_response, "close"):
            app_response.close()
        if not data:
            raise Exception("WSGI app did not call start_response")

        status_code = int(data["status"].split()[0])
        headers = data["headers"]
        header_set = set(k.lower() for (k, v) in headers)
        body = escape.utf8(body)
        if status_code != 304:
            # 304 Not Modified must not carry an entity, so default
            # Content-Length/Content-Type are filled in only for other
            # responses.
            if "content-length" not in header_set:
                headers.append(("Content-Length", str(len(body))))
            if "content-type" not in header_set:
                headers.append(("Content-Type", "text/html; charset=UTF-8"))
        if "server" not in header_set:
            headers.append(("Server", "TornadoServer/%s" % tornado.version))

        parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
        for key, value in headers:
            parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
        parts.append(b"\r\n")
        parts.append(body)
        request.write(b"".join(parts))
        request.finish()
        self._log(status_code, request)

    @staticmethod
    def environ(request):
        """Converts a `tornado.httpserver.HTTPRequest` to a WSGI environment."""
        hostport = request.host.split(":")
        if len(hostport) == 2:
            host = hostport[0]
            port = int(hostport[1])
        else:
            host = request.host
            port = 443 if request.protocol == "https" else 80
        environ = {
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": "",
            "PATH_INFO": to_wsgi_str(escape.url_unescape(
                request.path, encoding=None, plus=False)),
            "QUERY_STRING": request.query,
            "REMOTE_ADDR": request.remote_ip,
            "SERVER_NAME": host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": request.version,
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.protocol,
            "wsgi.input": BytesIO(escape.utf8(request.body)),
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": False,
            "wsgi.multiprocess": True,
            "wsgi.run_once": False,
        }
        if "Content-Type" in request.headers:
            environ["CONTENT_TYPE"] = request.headers.pop("Content-Type")
        if "Content-Length" in request.headers:
            environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length")
        for key, value in request.headers.items():
            # Remaining headers are exposed in CGI style, e.g.
            # "User-Agent" becomes environ["HTTP_USER_AGENT"].
            environ["HTTP_" + key.replace("-", "_").upper()] = value
        return environ
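
    # Sketch of the environment produced above for a hypothetical
    # "GET /status?verbose=1" request to host example.com over http
    # (abridged; exact values depend on the incoming request):
    #
    #     {
    #         "REQUEST_METHOD": "GET",
    #         "SCRIPT_NAME": "",
    #         "PATH_INFO": "/status",
    #         "QUERY_STRING": "verbose=1",
    #         "SERVER_NAME": "example.com",
    #         "SERVER_PORT": "80",
    #         "wsgi.url_scheme": "http",
    #         ...
    #     }
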
def _log(self, status_code, request):
if status_code < 400:
log_method = access_log.info
elif status_code < 500:
log_method = access_log.warning
else:
log_method = access_log.error
request_time = 1000.0 * request.request_time()
summary = request.method + " " + request.uri + " (" + \
request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)