Open Media Library Platform

This commit is contained in:
j 2013-10-11 19:28:32 +02:00
commit 411ad5b16f
5849 changed files with 1778641 additions and 0 deletions

View file

@ -0,0 +1,7 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web}.
"""

View file

@ -0,0 +1,82 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
General helpers for L{twisted.web} unit tests.
"""
from twisted.internet.defer import succeed
from twisted.web import server
from twisted.trial.unittest import TestCase
from twisted.python.failure import Failure
from twisted.web._flatten import flattenString
from twisted.web.error import FlattenerError
def _render(resource, request):
result = resource.render(request)
if isinstance(result, str):
request.write(result)
request.finish()
return succeed(None)
elif result is server.NOT_DONE_YET:
if request.finished:
return succeed(None)
else:
return request.notifyFinish()
else:
raise ValueError("Unexpected return value: %r" % (result,))
class FlattenTestCase(TestCase):
"""
A test case that assists with testing L{twisted.web._flatten}.
"""
def assertFlattensTo(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string.
"""
d = flattenString(None, root)
d.addCallback(lambda s: self.assertEqual(s, target))
return d
def assertFlattensImmediately(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string, and
performs no asynchronus Deferred anything.
This version is more convenient in tests which wish to make multiple
assertions about flattening, since it can be called multiple times
without having to add multiple callbacks.
@return: the result of rendering L{root}, which should be equivalent to
L{target}.
@rtype: L{bytes}
"""
results = []
it = self.assertFlattensTo(root, target)
it.addBoth(results.append)
# Do our best to clean it up if something goes wrong.
self.addCleanup(it.cancel)
if not results:
self.fail("Rendering did not complete immediately.")
result = results[0]
if isinstance(result, Failure):
result.raiseException()
return results[0]
def assertFlatteningRaises(self, root, exn):
"""
Assert flattening a root element raises a particular exception.
"""
d = self.assertFailure(self.assertFlattensTo(root, ''), FlattenerError)
d.addCallback(lambda exc: self.assertIsInstance(exc._exception, exn))
return d

View file

@ -0,0 +1,278 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helpers related to HTTP requests, used by tests.
"""
from __future__ import division, absolute_import
__all__ = ['DummyChannel', 'DummyRequest']
from io import BytesIO
from zope.interface import implementer
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import ISSLTransport
from twisted.web.http_headers import Headers
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET, Session, Site
class DummyChannel:
class TCP:
port = 80
disconnected = False
def __init__(self):
self.written = BytesIO()
self.producers = []
def getPeer(self):
return IPv4Address("TCP", '192.168.1.1', 12344)
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("Can only write bytes to a transport, not %r" % (data,))
self.written.write(data)
def writeSequence(self, iovec):
for data in iovec:
self.write(data)
def getHost(self):
return IPv4Address("TCP", '10.0.0.1', self.port)
def registerProducer(self, producer, streaming):
self.producers.append((producer, streaming))
def loseConnection(self):
self.disconnected = True
@implementer(ISSLTransport)
class SSL(TCP):
pass
site = Site(Resource())
def __init__(self):
self.transport = self.TCP()
def requestDone(self, request):
pass
class DummyRequest(object):
"""
Represents a dummy or fake request.
@ivar _finishedDeferreds: C{None} or a C{list} of L{Deferreds} which will
be called back with C{None} when C{finish} is called or which will be
errbacked if C{processingFailed} is called.
@type headers: C{dict}
@ivar headers: A mapping of header name to header value for all request
headers.
@type outgoingHeaders: C{dict}
@ivar outgoingHeaders: A mapping of header name to header value for all
response headers.
@type responseCode: C{int}
@ivar responseCode: The response code which was passed to
C{setResponseCode}.
@type written: C{list} of C{bytes}
@ivar written: The bytes which have been written to the request.
"""
uri = b'http://dummy/'
method = b'GET'
client = None
def registerProducer(self, prod,s):
self.go = 1
while self.go:
prod.resumeProducing()
def unregisterProducer(self):
self.go = 0
def __init__(self, postpath, session=None):
self.sitepath = []
self.written = []
self.finished = 0
self.postpath = postpath
self.prepath = []
self.session = None
self.protoSession = session or Session(0, self)
self.args = {}
self.outgoingHeaders = {}
self.requestHeaders = Headers()
self.responseHeaders = Headers()
self.responseCode = None
self.headers = {}
self._finishedDeferreds = []
self._serverName = b"dummy"
self.clientproto = b"HTTP/1.0"
def getHeader(self, name):
"""
Retrieve the value of a request header.
@type name: C{bytes}
@param name: The name of the request header for which to retrieve the
value. Header names are compared case-insensitively.
@rtype: C{bytes} or L{NoneType}
@return: The value of the specified request header.
"""
return self.headers.get(name.lower(), None)
def getAllHeaders(self):
"""
Retrieve all the values of the request headers as a dictionary.
@return: The entire C{headers} L{dict}.
"""
return self.headers
def setHeader(self, name, value):
"""TODO: make this assert on write() if the header is content-length
"""
self.outgoingHeaders[name.lower()] = value
def getSession(self):
if self.session:
return self.session
assert not self.written, "Session cannot be requested after data has been written."
self.session = self.protoSession
return self.session
def render(self, resource):
"""
Render the given resource as a response to this request.
This implementation only handles a few of the most common behaviors of
resources. It can handle a render method that returns a string or
C{NOT_DONE_YET}. It doesn't know anything about the semantics of
request methods (eg HEAD) nor how to set any particular headers.
Basically, it's largely broken, but sufficient for some tests at least.
It should B{not} be expanded to do all the same stuff L{Request} does.
Instead, L{DummyRequest} should be phased out and L{Request} (or some
other real code factored in a different way) used.
"""
result = resource.render(self)
if result is NOT_DONE_YET:
return
self.write(result)
self.finish()
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("write() only accepts bytes")
self.written.append(data)
def notifyFinish(self):
"""
Return a L{Deferred} which is called back with C{None} when the request
is finished. This will probably only work if you haven't called
C{finish} yet.
"""
finished = Deferred()
self._finishedDeferreds.append(finished)
return finished
def finish(self):
"""
Record that the request is finished and callback and L{Deferred}s
waiting for notification of this.
"""
self.finished = self.finished + 1
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.callback(None)
def processingFailed(self, reason):
"""
Errback and L{Deferreds} waiting for finish notification.
"""
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.errback(reason)
def addArg(self, name, value):
self.args[name] = [value]
def setResponseCode(self, code, message=None):
"""
Set the HTTP status response code, but takes care that this is called
before any data is written.
"""
assert not self.written, "Response code cannot be set after data has been written: %s." % "@@@@".join(self.written)
self.responseCode = code
self.responseMessage = message
def setLastModified(self, when):
assert not self.written, "Last-Modified cannot be set after data has been written: %s." % "@@@@".join(self.written)
def setETag(self, tag):
assert not self.written, "ETag cannot be set after data has been written: %s." % "@@@@".join(self.written)
def getClientIP(self):
"""
Return the IPv4 address of the client which made this request, if there
is one, otherwise C{None}.
"""
if isinstance(self.client, IPv4Address):
return self.client.host
return None
def getRequestHostname(self):
"""
Get a dummy hostname associated to the HTTP request.
@rtype: C{bytes}
@returns: a dummy hostname
"""
return self._serverName
def getHost(self):
"""
Get a dummy transport's host.
@rtype: C{IPv4Address}
@returns: a dummy transport's host
"""
return IPv4Address('TCP', '127.0.0.1', 80)
def getClient(self):
"""
Stub to get the client doing the HTTP request.
This merely just ensures that this method exists here. Feel free to
extend it.
"""

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,364 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.twcgi}.
"""
import sys, os
from twisted.trial import unittest
from twisted.internet import reactor, interfaces, error
from twisted.python import util, failure, log
from twisted.web.http import NOT_FOUND, INTERNAL_SERVER_ERROR
from twisted.web import client, twcgi, server, resource
from twisted.web.test._util import _render
from twisted.web.test.test_web import DummyRequest
DUMMY_CGI = '''\
print "Header: OK"
print
print "cgi output"
'''
DUAL_HEADER_CGI = '''\
print "Header: spam"
print "Header: eggs"
print
print "cgi output"
'''
BROKEN_HEADER_CGI = '''\
print "XYZ"
print
print "cgi output"
'''
SPECIAL_HEADER_CGI = '''\
print "Server: monkeys"
print "Date: last year"
print
print "cgi output"
'''
READINPUT_CGI = '''\
# this is an example of a correctly-written CGI script which reads a body
# from stdin, which only reads env['CONTENT_LENGTH'] bytes.
import os, sys
body_length = int(os.environ.get('CONTENT_LENGTH',0))
indata = sys.stdin.read(body_length)
print "Header: OK"
print
print "readinput ok"
'''
READALLINPUT_CGI = '''\
# this is an example of the typical (incorrect) CGI script which expects
# the server to close stdin when the body of the request is complete.
# A correct CGI should only read env['CONTENT_LENGTH'] bytes.
import sys
indata = sys.stdin.read()
print "Header: OK"
print
print "readallinput ok"
'''
NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI = '''\
print "content-type: text/cgi-duplicate-test"
print
print "cgi output"
'''
class PythonScript(twcgi.FilteredScript):
filter = sys.executable
class CGI(unittest.TestCase):
"""
Tests for L{twcgi.FilteredScript}.
"""
if not interfaces.IReactorProcess.providedBy(reactor):
skip = "CGI tests require a functional reactor.spawnProcess()"
def startServer(self, cgi):
root = resource.Resource()
cgipath = util.sibpath(__file__, cgi)
root.putChild("cgi", PythonScript(cgipath))
site = server.Site(root)
self.p = reactor.listenTCP(0, site)
return self.p.getHost().port
def tearDown(self):
if getattr(self, 'p', None):
return self.p.stopListening()
def writeCGI(self, source):
cgiFilename = os.path.abspath(self.mktemp())
cgiFile = file(cgiFilename, 'wt')
cgiFile.write(source)
cgiFile.close()
return cgiFilename
def testCGI(self):
cgiFilename = self.writeCGI(DUMMY_CGI)
portnum = self.startServer(cgiFilename)
d = client.getPage("http://localhost:%d/cgi" % portnum)
d.addCallback(self._testCGI_1)
return d
def _testCGI_1(self, res):
self.assertEqual(res, "cgi output" + os.linesep)
def test_protectedServerAndDate(self):
"""
If the CGI script emits a I{Server} or I{Date} header, these are
ignored.
"""
cgiFilename = self.writeCGI(SPECIAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
factory = client.HTTPClientFactory(url)
reactor.connectTCP('localhost', portnum, factory)
def checkResponse(ignored):
self.assertNotIn('monkeys', factory.response_headers['server'])
self.assertNotIn('last year', factory.response_headers['date'])
factory.deferred.addCallback(checkResponse)
return factory.deferred
def test_noDuplicateContentTypeHeaders(self):
"""
If the CGI script emits a I{content-type} header, make sure that the
server doesn't add an additional (duplicate) one, as per ticket 4786.
"""
cgiFilename = self.writeCGI(NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
factory = client.HTTPClientFactory(url)
reactor.connectTCP('localhost', portnum, factory)
def checkResponse(ignored):
self.assertEqual(
factory.response_headers['content-type'], ['text/cgi-duplicate-test'])
factory.deferred.addCallback(checkResponse)
return factory.deferred
def test_duplicateHeaderCGI(self):
"""
If a CGI script emits two instances of the same header, both are sent in
the response.
"""
cgiFilename = self.writeCGI(DUAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
factory = client.HTTPClientFactory(url)
reactor.connectTCP('localhost', portnum, factory)
def checkResponse(ignored):
self.assertEqual(
factory.response_headers['header'], ['spam', 'eggs'])
factory.deferred.addCallback(checkResponse)
return factory.deferred
def test_malformedHeaderCGI(self):
"""
Check for the error message in the duplicated header
"""
cgiFilename = self.writeCGI(BROKEN_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
factory = client.HTTPClientFactory(url)
reactor.connectTCP('localhost', portnum, factory)
loggedMessages = []
def addMessage(eventDict):
loggedMessages.append(log.textFromEventDict(eventDict))
log.addObserver(addMessage)
self.addCleanup(log.removeObserver, addMessage)
def checkResponse(ignored):
self.assertEqual(loggedMessages[0],
"ignoring malformed CGI header: 'XYZ'")
factory.deferred.addCallback(checkResponse)
return factory.deferred
def testReadEmptyInput(self):
cgiFilename = os.path.abspath(self.mktemp())
cgiFile = file(cgiFilename, 'wt')
cgiFile.write(READINPUT_CGI)
cgiFile.close()
portnum = self.startServer(cgiFilename)
d = client.getPage("http://localhost:%d/cgi" % portnum)
d.addCallback(self._testReadEmptyInput_1)
return d
testReadEmptyInput.timeout = 5
def _testReadEmptyInput_1(self, res):
self.assertEqual(res, "readinput ok%s" % os.linesep)
def testReadInput(self):
cgiFilename = os.path.abspath(self.mktemp())
cgiFile = file(cgiFilename, 'wt')
cgiFile.write(READINPUT_CGI)
cgiFile.close()
portnum = self.startServer(cgiFilename)
d = client.getPage("http://localhost:%d/cgi" % portnum,
method="POST",
postdata="Here is your stdin")
d.addCallback(self._testReadInput_1)
return d
testReadInput.timeout = 5
def _testReadInput_1(self, res):
self.assertEqual(res, "readinput ok%s" % os.linesep)
def testReadAllInput(self):
cgiFilename = os.path.abspath(self.mktemp())
cgiFile = file(cgiFilename, 'wt')
cgiFile.write(READALLINPUT_CGI)
cgiFile.close()
portnum = self.startServer(cgiFilename)
d = client.getPage("http://localhost:%d/cgi" % portnum,
method="POST",
postdata="Here is your stdin")
d.addCallback(self._testReadAllInput_1)
return d
testReadAllInput.timeout = 5
def _testReadAllInput_1(self, res):
self.assertEqual(res, "readallinput ok%s" % os.linesep)
def test_useReactorArgument(self):
"""
L{twcgi.FilteredScript.runProcess} uses the reactor passed as an
argument to the constructor.
"""
class FakeReactor:
"""
A fake reactor recording whether spawnProcess is called.
"""
called = False
def spawnProcess(self, *args, **kwargs):
"""
Set the C{called} flag to C{True} if C{spawnProcess} is called.
@param args: Positional arguments.
@param kwargs: Keyword arguements.
"""
self.called = True
fakeReactor = FakeReactor()
request = DummyRequest(['a', 'b'])
resource = twcgi.FilteredScript("dummy-file", reactor=fakeReactor)
_render(resource, request)
self.assertTrue(fakeReactor.called)
class CGIScriptTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIScript}.
"""
def test_pathInfo(self):
"""
L{twcgi.CGIScript.render} sets the process environment I{PATH_INFO} from
the request path.
"""
class FakeReactor:
"""
A fake reactor recording the environment passed to spawnProcess.
"""
def spawnProcess(self, process, filename, args, env, wdir):
"""
Store the C{env} L{dict} to an instance attribute.
@param process: Ignored
@param filename: Ignored
@param args: Ignored
@param env: The environment L{dict} which will be stored
@param wdir: Ignored
"""
self.process_env = env
_reactor = FakeReactor()
resource = twcgi.CGIScript(self.mktemp(), reactor=_reactor)
request = DummyRequest(['a', 'b'])
_render(resource, request)
self.assertEqual(_reactor.process_env["PATH_INFO"],
"/a/b")
class CGIDirectoryTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIDirectory}.
"""
def test_render(self):
"""
L{twcgi.CGIDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = twcgi.CGIDirectory(self.mktemp())
request = DummyRequest([''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{twcgi.CGIDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{twcgi.CGIDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = twcgi.CGIDirectory(path)
request = DummyRequest(['foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
class CGIProcessProtocolTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIProcessProtocol}.
"""
def test_prematureEndOfHeaders(self):
"""
If the process communicating with L{CGIProcessProtocol} ends before
finishing writing out headers, the response has I{INTERNAL SERVER
ERROR} as its status code.
"""
request = DummyRequest([''])
protocol = twcgi.CGIProcessProtocol(request)
protocol.processEnded(failure.Failure(error.ProcessTerminated()))
self.assertEqual(request.responseCode, INTERNAL_SERVER_ERROR)

View file

@ -0,0 +1,434 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.distrib}.
"""
from os.path import abspath
from xml.dom.minidom import parseString
try:
import pwd
except ImportError:
pwd = None
from zope.interface.verify import verifyObject
from twisted.python import log, filepath
from twisted.internet import reactor, defer
from twisted.trial import unittest
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.web import http, distrib, client, resource, static, server
from twisted.web.test.test_web import DummyRequest
from twisted.web.test._util import _render
from twisted.test import proto_helpers
class MySite(server.Site):
pass
class PBServerFactory(pb.PBServerFactory):
"""
A PB server factory which keeps track of the most recent protocol it
created.
@ivar proto: L{None} or the L{Broker} instance most recently returned
from C{buildProtocol}.
"""
proto = None
def buildProtocol(self, addr):
self.proto = pb.PBServerFactory.buildProtocol(self, addr)
return self.proto
class DistribTest(unittest.TestCase):
port1 = None
port2 = None
sub = None
f1 = None
def tearDown(self):
"""
Clean up all the event sources left behind by either directly by
test methods or indirectly via some distrib API.
"""
dl = [defer.Deferred(), defer.Deferred()]
if self.f1 is not None and self.f1.proto is not None:
self.f1.proto.notifyOnDisconnect(lambda: dl[0].callback(None))
else:
dl[0].callback(None)
if self.sub is not None and self.sub.publisher is not None:
self.sub.publisher.broker.notifyOnDisconnect(
lambda: dl[1].callback(None))
self.sub.publisher.broker.transport.loseConnection()
else:
dl[1].callback(None)
if self.port1 is not None:
dl.append(self.port1.stopListening())
if self.port2 is not None:
dl.append(self.port2.stopListening())
return defer.gatherResults(dl)
def testDistrib(self):
# site1 is the publisher
r1 = resource.Resource()
r1.putChild("there", static.Data("root", "text/plain"))
site1 = server.Site(r1)
self.f1 = PBServerFactory(distrib.ResourcePublisher(site1))
self.port1 = reactor.listenTCP(0, self.f1)
self.sub = distrib.ResourceSubscription("127.0.0.1",
self.port1.getHost().port)
r2 = resource.Resource()
r2.putChild("here", self.sub)
f2 = MySite(r2)
self.port2 = reactor.listenTCP(0, f2)
d = client.getPage("http://127.0.0.1:%d/here/there" % \
self.port2.getHost().port)
d.addCallback(self.assertEqual, 'root')
return d
def _setupDistribServer(self, child):
"""
Set up a resource on a distrib site using L{ResourcePublisher}.
@param child: The resource to publish using distrib.
@return: A tuple consisting of the host and port on which to contact
the created site.
"""
distribRoot = resource.Resource()
distribRoot.putChild("child", child)
distribSite = server.Site(distribRoot)
self.f1 = distribFactory = PBServerFactory(
distrib.ResourcePublisher(distribSite))
distribPort = reactor.listenTCP(
0, distribFactory, interface="127.0.0.1")
self.addCleanup(distribPort.stopListening)
addr = distribPort.getHost()
self.sub = mainRoot = distrib.ResourceSubscription(
addr.host, addr.port)
mainSite = server.Site(mainRoot)
mainPort = reactor.listenTCP(0, mainSite, interface="127.0.0.1")
self.addCleanup(mainPort.stopListening)
mainAddr = mainPort.getHost()
return mainPort, mainAddr
def _requestTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{getPage} when
requesting the resource.
@return: A L{Deferred} which fires with the result of the request.
"""
mainPort, mainAddr = self._setupDistribServer(child)
return client.getPage("http://%s:%s/child" % (
mainAddr.host, mainAddr.port), **kwargs)
def _requestAgentTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with a tuple consisting of a
L{twisted.test.proto_helpers.AccumulatingProtocol} containing the
body of the response and an L{IResponse} with the response itself.
"""
mainPort, mainAddr = self._setupDistribServer(child)
d = client.Agent(reactor).request("GET", "http://%s:%s/child" % (
mainAddr.host, mainAddr.port), **kwargs)
def cbCollectBody(response):
protocol = proto_helpers.AccumulatingProtocol()
response.deliverBody(protocol)
d = protocol.closedDeferred = defer.Deferred()
d.addCallback(lambda _: (protocol, response))
return d
d.addCallback(cbCollectBody)
return d
def test_requestHeaders(self):
"""
The request headers are available on the request object passed to a
distributed resource's C{render} method.
"""
requestHeaders = {}
class ReportRequestHeaders(resource.Resource):
def render(self, request):
requestHeaders.update(dict(
request.requestHeaders.getAllRawHeaders()))
return ""
request = self._requestTest(
ReportRequestHeaders(), headers={'foo': 'bar'})
def cbRequested(result):
self.assertEqual(requestHeaders['Foo'], ['bar'])
request.addCallback(cbRequested)
return request
def test_requestResponseCode(self):
"""
The response code can be set by the request object passed to a
distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200)
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, "")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, "OK")
request.addCallback(cbRequested)
return request
def test_requestResponseCodeMessage(self):
"""
The response code and message can be set by the request object passed to
a distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200, "some-message")
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, "")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, "some-message")
request.addCallback(cbRequested)
return request
def test_largeWrite(self):
"""
If a string longer than the Banana size limit is passed to the
L{distrib.Request} passed to the remote resource, it is broken into
smaller strings to be transported over the PB connection.
"""
class LargeWrite(resource.Resource):
def render(self, request):
request.write('x' * SIZE_LIMIT + 'y')
request.finish()
return server.NOT_DONE_YET
request = self._requestTest(LargeWrite())
request.addCallback(self.assertEqual, 'x' * SIZE_LIMIT + 'y')
return request
def test_largeReturn(self):
"""
Like L{test_largeWrite}, but for the case where C{render} returns a
long string rather than explicitly passing it to L{Request.write}.
"""
class LargeReturn(resource.Resource):
def render(self, request):
return 'x' * SIZE_LIMIT + 'y'
request = self._requestTest(LargeReturn())
request.addCallback(self.assertEqual, 'x' * SIZE_LIMIT + 'y')
return request
def test_connectionLost(self):
"""
If there is an error issuing the request to the remote publisher, an
error response is returned.
"""
# Using pb.Root as a publisher will cause request calls to fail with an
# error every time. Just what we want to test.
self.f1 = serverFactory = PBServerFactory(pb.Root())
self.port1 = serverPort = reactor.listenTCP(0, serverFactory)
self.sub = subscription = distrib.ResourceSubscription(
"127.0.0.1", serverPort.getHost().port)
request = DummyRequest([''])
d = _render(subscription, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 500)
# This is the error we caused the request to fail with. It should
# have been logged.
self.assertEqual(len(self.flushLoggedErrors(pb.NoSuchMethod)), 1)
d.addCallback(cbRendered)
return d
class _PasswordDatabase:
def __init__(self, users):
self._users = users
def getpwall(self):
return iter(self._users)
def getpwnam(self, username):
for user in self._users:
if user[0] == username:
return user
raise KeyError()
class UserDirectoryTests(unittest.TestCase):
"""
Tests for L{UserDirectory}, a resource for listing all user resources
available on a system.
"""
def setUp(self):
self.alice = ('alice', 'x', 123, 456, 'Alice,,,', self.mktemp(), '/bin/sh')
self.bob = ('bob', 'x', 234, 567, 'Bob,,,', self.mktemp(), '/bin/sh')
self.database = _PasswordDatabase([self.alice, self.bob])
self.directory = distrib.UserDirectory(self.database)
def test_interface(self):
"""
L{UserDirectory} instances provide L{resource.IResource}.
"""
self.assertTrue(verifyObject(resource.IResource, self.directory))
def _404Test(self, name):
"""
Verify that requesting the C{name} child of C{self.directory} results
in a 404 response.
"""
request = DummyRequest([name])
result = self.directory.getChild(name, request)
d = _render(result, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 404)
d.addCallback(cbRendered)
return d
def test_getInvalidUser(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which does not correspond to any known
user.
"""
return self._404Test('carol')
def test_getUserWithoutResource(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which corresponds to a known user who has
neither a user directory nor a user distrib socket.
"""
return self._404Test('alice')
def test_getPublicHTMLChild(self):
"""
L{UserDirectory.getChild} returns a L{static.File} instance when passed
the name of a user with a home directory containing a I{public_html}
directory.
"""
home = filepath.FilePath(self.bob[-2])
public_html = home.child('public_html')
public_html.makedirs()
request = DummyRequest(['bob'])
result = self.directory.getChild('bob', request)
self.assertIsInstance(result, static.File)
self.assertEqual(result.path, public_html.path)
def test_getDistribChild(self):
"""
L{UserDirectory.getChild} returns a L{ResourceSubscription} instance
when passed the name of a user suffixed with C{".twistd"} who has a
home directory containing a I{.twistd-web-pb} socket.
"""
home = filepath.FilePath(self.bob[-2])
home.makedirs()
web = home.child('.twistd-web-pb')
request = DummyRequest(['bob'])
result = self.directory.getChild('bob.twistd', request)
self.assertIsInstance(result, distrib.ResourceSubscription)
self.assertEqual(result.host, 'unix')
self.assertEqual(abspath(result.port), web.path)
def test_invalidMethod(self):
"""
L{UserDirectory.render} raises L{UnsupportedMethod} in response to a
non-I{GET} request.
"""
request = DummyRequest([''])
request.method = 'POST'
self.assertRaises(
server.UnsupportedMethod, self.directory.render, request)
def test_render(self):
"""
L{UserDirectory} renders a list of links to available user content
in response to a I{GET} request.
"""
public_html = filepath.FilePath(self.alice[-2]).child('public_html')
public_html.makedirs()
web = filepath.FilePath(self.bob[-2])
web.makedirs()
# This really only works if it's a unix socket, but the implementation
# doesn't currently check for that. It probably should someday, and
# then skip users with non-sockets.
web.child('.twistd-web-pb').setContent("")
request = DummyRequest([''])
result = _render(self.directory, request)
def cbRendered(ignored):
document = parseString(''.join(request.written))
# Each user should have an li with a link to their page.
[alice, bob] = document.getElementsByTagName('li')
self.assertEqual(alice.firstChild.tagName, 'a')
self.assertEqual(alice.firstChild.getAttribute('href'), 'alice/')
self.assertEqual(alice.firstChild.firstChild.data, 'Alice (file)')
self.assertEqual(bob.firstChild.tagName, 'a')
self.assertEqual(bob.firstChild.getAttribute('href'), 'bob.twistd/')
self.assertEqual(bob.firstChild.firstChild.data, 'Bob (twistd)')
result.addCallback(cbRendered)
return result
def test_passwordDatabase(self):
"""
If L{UserDirectory} is instantiated with no arguments, it uses the
L{pwd} module as its password database.
"""
directory = distrib.UserDirectory()
self.assertIdentical(directory._pwd, pwd)
if pwd is None:
test_passwordDatabase.skip = "pwd module required"

View file

@ -0,0 +1,306 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Specific tests for (some of) the methods in L{twisted.web.domhelpers}.
"""
from xml.dom import minidom
from twisted.trial.unittest import TestCase
from twisted.web import microdom
from twisted.web import domhelpers
class DOMHelpersTestsMixin:
"""
A mixin for L{TestCase} subclasses which defines test methods for
domhelpers functionality based on a DOM creation function provided by a
subclass.
"""
dom = None
def test_getElementsByTagName(self):
doc1 = self.dom.parseString('<foo/>')
actual=domhelpers.getElementsByTagName(doc1, 'foo')[0].nodeName
expected='foo'
self.assertEqual(actual, expected)
el1=doc1.documentElement
actual=domhelpers.getElementsByTagName(el1, 'foo')[0].nodeName
self.assertEqual(actual, expected)
doc2_xml='<a><foo in="a"/><b><foo in="b"/></b><c><foo in="c"/></c><foo in="d"/><foo in="ef"/><g><foo in="g"/><h><foo in="h"/></h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
tag_list=domhelpers.getElementsByTagName(doc2, 'foo')
actual=''.join([node.getAttribute('in') for node in tag_list])
expected='abcdefgh'
self.assertEqual(actual, expected)
el2=doc2.documentElement
tag_list=domhelpers.getElementsByTagName(el2, 'foo')
actual=''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc3_xml='''
<a><foo in="a"/>
<b><foo in="b"/>
<d><foo in="d"/>
<g><foo in="g"/></g>
<h><foo in="h"/></h>
</d>
<e><foo in="e"/>
<i><foo in="i"/></i>
</e>
</b>
<c><foo in="c"/>
<f><foo in="f"/>
<j><foo in="j"/></j>
</f>
</c>
</a>'''
doc3 = self.dom.parseString(doc3_xml)
tag_list=domhelpers.getElementsByTagName(doc3, 'foo')
actual=''.join([node.getAttribute('in') for node in tag_list])
expected='abdgheicfj'
self.assertEqual(actual, expected)
el3=doc3.documentElement
tag_list=domhelpers.getElementsByTagName(el3, 'foo')
actual=''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc4_xml='<foo><bar></bar><baz><foo/></baz></foo>'
doc4 = self.dom.parseString(doc4_xml)
actual=domhelpers.getElementsByTagName(doc4, 'foo')
root=doc4.documentElement
expected=[root, root.childNodes[-1].childNodes[0]]
self.assertEqual(actual, expected)
actual=domhelpers.getElementsByTagName(root, 'foo')
self.assertEqual(actual, expected)
def test_gatherTextNodes(self):
doc1 = self.dom.parseString('<a>foo</a>')
actual=domhelpers.gatherTextNodes(doc1)
expected='foo'
self.assertEqual(actual, expected)
actual=domhelpers.gatherTextNodes(doc1.documentElement)
self.assertEqual(actual, expected)
doc2_xml='<a>a<b>b</b><c>c</c>def<g>g<h>h</h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
actual=domhelpers.gatherTextNodes(doc2)
expected='abcdefgh'
self.assertEqual(actual, expected)
actual=domhelpers.gatherTextNodes(doc2.documentElement)
self.assertEqual(actual, expected)
doc3_xml=('<a>a<b>b<d>d<g>g</g><h>h</h></d><e>e<i>i</i></e></b>' +
'<c>c<f>f<j>j</j></f></c></a>')
doc3 = self.dom.parseString(doc3_xml)
actual=domhelpers.gatherTextNodes(doc3)
expected='abdgheicfj'
self.assertEqual(actual, expected)
actual=domhelpers.gatherTextNodes(doc3.documentElement)
self.assertEqual(actual, expected)
def test_clearNode(self):
doc1 = self.dom.parseString('<a><b><c><d/></c></b></a>')
a_node=doc1.documentElement
domhelpers.clearNode(a_node)
self.assertEqual(
a_node.toxml(),
self.dom.Element('a').toxml())
doc2 = self.dom.parseString('<a><b><c><d/></c></b></a>')
b_node=doc2.documentElement.childNodes[0]
domhelpers.clearNode(b_node)
actual=doc2.documentElement.toxml()
expected = self.dom.Element('a')
expected.appendChild(self.dom.Element('b'))
self.assertEqual(actual, expected.toxml())
def test_get(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
node=domhelpers.get(doc1, "foo")
actual=node.toxml()
expected = self.dom.Element('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node=domhelpers.get(doc1, "bar")
actual=node.toxml()
expected = self.dom.Element('b')
expected.setAttribute('id', 'bar')
self.assertEqual(actual, expected.toxml())
self.assertRaises(domhelpers.NodeLookupError,
domhelpers.get,
doc1,
"pzork")
def test_getIfExists(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
node=domhelpers.getIfExists(doc1, "foo")
actual=node.toxml()
expected = self.dom.Element('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node=domhelpers.getIfExists(doc1, "pzork")
self.assertIdentical(node, None)
def test_getAndClear(self):
doc1 = self.dom.parseString('<a><b id="foo"><c></c></b></a>')
node=domhelpers.getAndClear(doc1, "foo")
actual=node.toxml()
expected = self.dom.Element('b')
expected.setAttribute('id', 'foo')
self.assertEqual(actual, expected.toxml())
def test_locateNodes(self):
doc1 = self.dom.parseString('<a><b foo="olive"><c foo="olive"/></b><d foo="poopy"/></a>')
node_list=domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=1)
actual=''.join([node.toxml() for node in node_list])
expected = self.dom.Element('b')
expected.setAttribute('foo', 'olive')
c = self.dom.Element('c')
c.setAttribute('foo', 'olive')
expected.appendChild(c)
self.assertEqual(actual, expected.toxml())
node_list=domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=0)
actual=''.join([node.toxml() for node in node_list])
self.assertEqual(actual, expected.toxml() + c.toxml())
def test_getParents(self):
doc1 = self.dom.parseString('<a><b><c><d/></c><e/></b><f/></a>')
node_list = domhelpers.getParents(
doc1.childNodes[0].childNodes[0].childNodes[0])
actual = ''.join([node.tagName for node in node_list
if hasattr(node, 'tagName')])
self.assertEqual(actual, 'cba')
def test_findElementsWithAttribute(self):
doc1 = self.dom.parseString('<a foo="1"><b foo="2"/><c foo="1"/><d/></a>')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'abc')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo', '1')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'ac')
def test_findNodesNamed(self):
doc1 = self.dom.parseString('<doc><foo/><bar/><foo>a</foo></doc>')
node_list = domhelpers.findNodesNamed(doc1, 'foo')
actual = len(node_list)
self.assertEqual(actual, 2)
# NOT SURE WHAT THESE ARE SUPPOSED TO DO..
# def test_RawText FIXME
# def test_superSetAttribute FIXME
# def test_superPrependAttribute FIXME
# def test_superAppendAttribute FIXME
# def test_substitute FIXME
def test_escape(self):
j='this string " contains many & characters> xml< won\'t like'
expected='this string &quot; contains many &amp; characters&gt; xml&lt; won\'t like'
self.assertEqual(domhelpers.escape(j), expected)
def test_unescape(self):
j='this string &quot; has &&amp; entities &gt; &lt; and some characters xml won\'t like<'
expected='this string " has && entities > < and some characters xml won\'t like<'
self.assertEqual(domhelpers.unescape(j), expected)
def test_getNodeText(self):
"""
L{getNodeText} returns the concatenation of all the text data at or
beneath the node passed to it.
"""
node = self.dom.parseString('<foo><bar>baz</bar><bar>quux</bar></foo>')
self.assertEqual(domhelpers.getNodeText(node), "bazquux")
class MicroDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = microdom
def test_gatherTextNodesDropsWhitespace(self):
"""
Microdom discards whitespace-only text nodes, so L{gatherTextNodes}
returns only the text from nodes which had non-whitespace characters.
"""
doc4_xml='''<html>
<head>
</head>
<body>
stuff
</body>
</html>
'''
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.gatherTextNodes(doc4)
expected = '\n stuff\n '
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc4.documentElement)
self.assertEqual(actual, expected)
def test_textEntitiesNotDecoded(self):
"""
Microdom does not decode entities in text nodes.
"""
doc5_xml='<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual=domhelpers.gatherTextNodes(doc5)
expected='Souffl&amp;'
self.assertEqual(actual, expected)
actual=domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
class MiniDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = minidom
def test_textEntitiesDecoded(self):
"""
Minidom does decode entities in text nodes.
"""
doc5_xml='<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual=domhelpers.gatherTextNodes(doc5)
expected='Souffl&'
self.assertEqual(actual, expected)
actual=domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
def test_getNodeUnicodeText(self):
"""
L{domhelpers.getNodeText} returns a C{unicode} string when text
nodes are represented in the DOM with unicode, whether or not there
are non-ASCII characters present.
"""
node = self.dom.parseString("<foo>bar</foo>")
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"bar")
self.assertIsInstance(text, unicode)
node = self.dom.parseString(u"<foo>\N{SNOWMAN}</foo>".encode('utf-8'))
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"\N{SNOWMAN}")
self.assertIsInstance(text, unicode)

View file

@ -0,0 +1,151 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP errors.
"""
from twisted.trial import unittest
from twisted.web import error
class ErrorTestCase(unittest.TestCase):
"""
Tests for how L{Error} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and the
C{code} argument is a valid HTTP status code, C{code} is mapped to a
descriptive string to which C{message} is assigned.
"""
e = error.Error("200")
self.assertEqual(e.message, "OK")
def test_noMessageInvalidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and
C{code} isn't a valid HTTP status code, C{message} stays C{None}.
"""
e = error.Error("InvalidCode")
self.assertEqual(e.message, None)
def test_messageExists(self):
"""
If a C{message} argument is passed to the L{Error} constructor, the
C{message} isn't affected by the value of C{status}.
"""
e = error.Error("200", "My own message")
self.assertEqual(e.message, "My own message")
class PageRedirectTestCase(unittest.TestCase):
"""
Tests for how L{PageRedirect} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and the C{code} argument is a valid HTTP status code, C{code} is mapped
to a descriptive string to which C{message} is assigned.
"""
e = error.PageRedirect("200", location="/foo")
self.assertEqual(e.message, "OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{location} is also empty and the C{code} argument is a valid HTTP
status code, C{code} is mapped to a descriptive string to which
C{message} is assigned without trying to include an empty location.
"""
e = error.PageRedirect("200")
self.assertEqual(e.message, "OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{code} isn't a valid HTTP status code, C{message} stays C{None}.
"""
e = error.PageRedirect("InvalidCode", location="/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor,
the C{message} isn't affected by the value of C{status}.
"""
e = error.PageRedirect("200", "My own message", location="/foo")
self.assertEqual(e.message, "My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor
and no location is provided, C{message} doesn't try to include the empty
location.
"""
e = error.PageRedirect("200", "My own message")
self.assertEqual(e.message, "My own message")
class InfiniteRedirectionTestCase(unittest.TestCase):
"""
Tests for how L{InfiniteRedirection} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and the C{code} argument is a valid HTTP status code,
C{code} is mapped to a descriptive string to which C{message} is
assigned.
"""
e = error.InfiniteRedirection("200", location="/foo")
self.assertEqual(e.message, "OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{location} is also empty and the C{code} argument is a
valid HTTP status code, C{code} is mapped to a descriptive string to
which C{message} is assigned without trying to include an empty
location.
"""
e = error.InfiniteRedirection("200")
self.assertEqual(e.message, "OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{code} isn't a valid HTTP status code, C{message} stays
C{None}.
"""
e = error.InfiniteRedirection("InvalidCode", location="/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor, the C{message} isn't affected by the value of C{status}.
"""
e = error.InfiniteRedirection("200", "My own message", location="/foo")
self.assertEqual(e.message, "My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor and no location is provided, C{message} doesn't try to
include the empty location.
"""
e = error.InfiniteRedirection("200", "My own message")
self.assertEqual(e.message, "My own message")

View file

@ -0,0 +1,554 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the flattening portion of L{twisted.web.template}, implemented in
L{twisted.web._flatten}.
"""
import sys
import traceback
from xml.etree.cElementTree import XML
from zope.interface import implements, implementer
from twisted.trial.unittest import TestCase
from twisted.test.testutils import XMLAssertionMixin
from twisted.internet.defer import passthru, succeed, gatherResults
from twisted.web.iweb import IRenderable
from twisted.web.error import UnfilledSlot, UnsupportedType, FlattenerError
from twisted.web.template import tags, Tag, Comment, CDATA, CharRef, slot
from twisted.web.template import Element, renderer, TagLoader, flattenString
from twisted.web.test._util import FlattenTestCase
class OrderedAttributes(object):
"""
An L{OrderedAttributes} is a stand-in for the L{Tag.attributes} dictionary
that orders things in a deterministic order. It doesn't do any sorting, so
whatever order the attributes are passed in, they will be returned.
@ivar attributes: The result of a L{dict}C{.items} call.
@type attributes: L{list} of 2-L{tuples}
"""
def __init__(self, attributes):
self.attributes = attributes
def iteritems(self):
"""
Like L{dict}C{.iteritems}.
@return: an iterator
@rtype: list iterator
"""
return iter(self.attributes)
class TestSerialization(FlattenTestCase, XMLAssertionMixin):
"""
Tests for flattening various things.
"""
def test_nestedTags(self):
"""
Test that nested tags flatten correctly.
"""
return self.assertFlattensTo(
tags.html(tags.body('42'), hi='there'),
'<html hi="there"><body>42</body></html>')
def test_serializeString(self):
"""
Test that strings will be flattened and escaped correctly.
"""
return gatherResults([
self.assertFlattensTo('one', 'one'),
self.assertFlattensTo('<abc&&>123', '&lt;abc&amp;&amp;&gt;123'),
])
def test_serializeSelfClosingTags(self):
"""
The serialized form of a self-closing tag is C{'<tagName />'}.
"""
return self.assertFlattensTo(tags.img(), '<img />')
def test_serializeAttribute(self):
"""
The serialized form of attribute I{a} with value I{b} is C{'a="b"'}.
"""
self.assertFlattensImmediately(tags.img(src='foo'),
'<img src="foo" />')
def test_serializedMultipleAttributes(self):
"""
Multiple attributes are separated by a single space in their serialized
form.
"""
tag = tags.img()
tag.attributes = OrderedAttributes([("src", "foo"), ("name", "bar")])
self.assertFlattensImmediately(tag, '<img src="foo" name="bar" />')
def checkAttributeSanitization(self, wrapData, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithSanitization}
and L{test_serializedDeferredAttributeWithSanitization},
L{test_serializedAttributeWithTransparentTag}.
@param wrapData: A 1-argument callable that wraps around the
attribute's value so other tests can customize it.
@param wrapData: callable taking L{bytes} and returning something
flattenable
@param wrapTag: A 1-argument callable that wraps around the outer tag
so other tests can customize it.
@type wrapTag: callable taking L{Tag} and returning L{Tag}.
"""
self.assertFlattensImmediately(
wrapTag(tags.img(src=wrapData("<>&\""))),
'<img src="&lt;&gt;&amp;&quot;" />')
def test_serializedAttributeWithSanitization(self):
"""
Attribute values containing C{"<"}, C{">"}, C{"&"}, or C{'"'} have
C{"&lt;"}, C{"&gt;"}, C{"&amp;"}, or C{"&quot;"} substituted for those
bytes in the serialized output.
"""
self.checkAttributeSanitization(passthru, passthru)
def test_serializedDeferredAttributeWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization}, but when the contents
of the attribute are in a L{Deferred
<twisted.internet.defer.Deferred>}.
"""
self.checkAttributeSanitization(succeed, passthru)
def test_serializedAttributeWithSlotWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization} but with a slot.
"""
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or slot("stuff"),
lambda tag: tag.fillSlots(stuff=toss.pop())
)
def test_serializedAttributeWithTransparentTag(self):
"""
Attribute values which are supplied via the value of a C{t:transparent}
tag have the same subsitution rules to them as values supplied
directly.
"""
self.checkAttributeSanitization(tags.transparent, passthru)
def test_serializedAttributeWithTransparentTagWithRenderer(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is rendered by a renderer on an element.
"""
class WithRenderer(Element):
def __init__(self, value, loader):
self.value = value
super(WithRenderer, self).__init__(loader)
@renderer
def stuff(self, request, tag):
return self.value
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or
tags.transparent(render="stuff"),
lambda tag: WithRenderer(toss.pop(), TagLoader(tag))
)
def test_serializedAttributeWithRenderable(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is a provider of L{IRenderable} rather than a transparent
tag.
"""
@implementer(IRenderable)
class Arbitrary(object):
def __init__(self, value):
self.value = value
def render(self, request):
return self.value
self.checkAttributeSanitization(Arbitrary, passthru)
def checkTagAttributeSerialization(self, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithTag} and
L{test_serializedAttributeWithDeferredTag}.
@param wrapTag: A 1-argument callable that wraps around the attribute's
value so other tests can customize it.
@param wrapTag: callable taking L{Tag} and returning something
flattenable
"""
innerTag = tags.a('<>&"')
outerTag = tags.img(src=wrapTag(innerTag))
outer = self.assertFlattensImmediately(
outerTag,
'<img src="&lt;a&gt;&amp;lt;&amp;gt;&amp;amp;&quot;&lt;/a&gt;" />')
inner = self.assertFlattensImmediately(
innerTag, '<a>&lt;&gt;&amp;"</a>')
# Since the above quoting is somewhat tricky, validate it by making sure
# that the main use-case for tag-within-attribute is supported here: if
# we serialize a tag, it is quoted *such that it can be parsed out again
# as a tag*.
self.assertXMLEqual(XML(outer).attrib['src'], inner)
def test_serializedAttributeWithTag(self):
"""
L{Tag} objects which are serialized within the context of an attribute
are serialized such that the text content of the attribute may be
parsed to retrieve the tag.
"""
self.checkTagAttributeSerialization(passthru)
def test_serializedAttributeWithDeferredTag(self):
"""
Like L{test_serializedAttributeWithTag}, but when the L{Tag} is in a
L{Deferred <twisted.internet.defer.Deferred>}.
"""
self.checkTagAttributeSerialization(succeed)
def test_serializedAttributeWithTagWithAttribute(self):
"""
Similar to L{test_serializedAttributeWithTag}, but for the additional
complexity where the tag which is the attribute value itself has an
attribute value which contains bytes which require substitution.
"""
flattened = self.assertFlattensImmediately(
tags.img(src=tags.a(href='<>&"')),
'<img src="&lt;a href='
'&quot;&amp;lt;&amp;gt;&amp;amp;&amp;quot;&quot;&gt;'
'&lt;/a&gt;" />')
# As in checkTagAttributeSerialization, belt-and-suspenders:
self.assertXMLEqual(XML(flattened).attrib['src'],
'<a href="&lt;&gt;&amp;&quot;"></a>')
def test_serializeComment(self):
"""
Test that comments are correctly flattened and escaped.
"""
return self.assertFlattensTo(Comment('foo bar'), '<!--foo bar-->'),
def test_commentEscaping(self):
"""
The data in a L{Comment} is escaped and mangled in the flattened output
so that the result is a legal SGML and XML comment.
SGML comment syntax is complicated and hard to use. This rule is more
restrictive, and more compatible:
Comments start with <!-- and end with --> and never contain -- or >.
Also by XML syntax, a comment may not end with '-'.
@see: U{http://www.w3.org/TR/REC-xml/#sec-comments}
"""
def verifyComment(c):
self.assertTrue(
c.startswith('<!--'),
"%r does not start with the comment prefix" % (c,))
self.assertTrue(
c.endswith('-->'),
"%r does not end with the comment suffix" % (c,))
# If it is shorter than 7, then the prefix and suffix overlap
# illegally.
self.assertTrue(
len(c) >= 7,
"%r is too short to be a legal comment" % (c,))
content = c[4:-3]
self.assertNotIn('--', content)
self.assertNotIn('>', content)
if content:
self.assertNotEqual(content[-1], '-')
results = []
for c in [
'',
'foo---bar',
'foo---bar-',
'foo>bar',
'foo-->bar',
'----------------',
]:
d = flattenString(None, Comment(c))
d.addCallback(verifyComment)
results.append(d)
return gatherResults(results)
def test_serializeCDATA(self):
"""
Test that CDATA is correctly flattened and escaped.
"""
return gatherResults([
self.assertFlattensTo(CDATA('foo bar'), '<![CDATA[foo bar]]>'),
self.assertFlattensTo(
CDATA('foo ]]> bar'),
'<![CDATA[foo ]]]]><![CDATA[> bar]]>'),
])
def test_serializeUnicode(self):
"""
Test that unicode is encoded correctly in the appropriate places, and
raises an error when it occurs in inappropriate place.
"""
snowman = u'\N{SNOWMAN}'
return gatherResults([
self.assertFlattensTo(snowman, '\xe2\x98\x83'),
self.assertFlattensTo(tags.p(snowman), '<p>\xe2\x98\x83</p>'),
self.assertFlattensTo(Comment(snowman), '<!--\xe2\x98\x83-->'),
self.assertFlattensTo(CDATA(snowman), '<![CDATA[\xe2\x98\x83]]>'),
self.assertFlatteningRaises(
Tag(snowman), UnicodeEncodeError),
self.assertFlatteningRaises(
Tag('p', attributes={snowman: ''}), UnicodeEncodeError),
])
def test_serializeCharRef(self):
"""
A character reference is flattened to a string using the I{&#NNNN;}
syntax.
"""
ref = CharRef(ord(u"\N{SNOWMAN}"))
return self.assertFlattensTo(ref, "&#9731;")
def test_serializeDeferred(self):
"""
Test that a deferred is substituted with the current value in the
callback chain when flattened.
"""
return self.assertFlattensTo(succeed('two'), 'two')
def test_serializeSameDeferredTwice(self):
"""
Test that the same deferred can be flattened twice.
"""
d = succeed('three')
return gatherResults([
self.assertFlattensTo(d, 'three'),
self.assertFlattensTo(d, 'three'),
])
def test_serializeIRenderable(self):
"""
Test that flattening respects all of the IRenderable interface.
"""
class FakeElement(object):
implements(IRenderable)
def render(ign,ored):
return tags.p(
'hello, ',
tags.transparent(render='test'), ' - ',
tags.transparent(render='test'))
def lookupRenderMethod(ign, name):
self.assertEqual(name, 'test')
return lambda ign, node: node('world')
return gatherResults([
self.assertFlattensTo(FakeElement(), '<p>hello, world - world</p>'),
])
def test_serializeSlots(self):
"""
Test that flattening a slot will use the slot value from the tag.
"""
t1 = tags.p(slot('test'))
t2 = t1.clone()
t2.fillSlots(test='hello, world')
return gatherResults([
self.assertFlatteningRaises(t1, UnfilledSlot),
self.assertFlattensTo(t2, '<p>hello, world</p>'),
])
def test_serializeDeferredSlots(self):
"""
Test that a slot with a deferred as its value will be flattened using
the value from the deferred.
"""
t = tags.p(slot('test'))
t.fillSlots(test=succeed(tags.em('four>')))
return self.assertFlattensTo(t, '<p><em>four&gt;</em></p>')
def test_unknownTypeRaises(self):
"""
Test that flattening an unknown type of thing raises an exception.
"""
return self.assertFlatteningRaises(None, UnsupportedType)
# Use the co_filename mechanism (instead of the __file__ mechanism) because
# it is the mechanism traceback formatting uses. The two do not necessarily
# agree with each other. This requires a code object compiled in this file.
# The easiest way to get a code object is with a new function. I'll use a
# lambda to avoid adding anything else to this namespace. The result will
# be a string which agrees with the one the traceback module will put into a
# traceback for frames associated with functions defined in this file.
HERE = (lambda: None).func_code.co_filename
class FlattenerErrorTests(TestCase):
"""
Tests for L{FlattenerError}.
"""
def test_string(self):
"""
If a L{FlattenerError} is created with a string root, up to around 40
bytes from that string are included in the string representation of the
exception.
"""
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), ['abc123xyz'], [])),
"Exception while flattening:\n"
" 'abc123xyz'\n"
"RuntimeError: reason\n")
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), ['0123456789' * 10], [])),
"Exception while flattening:\n"
" '01234567890123456789<...>01234567890123456789'\n"
"RuntimeError: reason\n")
def test_unicode(self):
"""
If a L{FlattenerError} is created with a unicode root, up to around 40
characters from that string are included in the string representation
of the exception.
"""
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), [u'abc\N{SNOWMAN}xyz'], [])),
"Exception while flattening:\n"
" u'abc\\u2603xyz'\n" # Codepoint for SNOWMAN
"RuntimeError: reason\n")
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), [u'01234567\N{SNOWMAN}9' * 10],
[])),
"Exception while flattening:\n"
" u'01234567\\u2603901234567\\u26039<...>01234567\\u2603901234567"
"\\u26039'\n"
"RuntimeError: reason\n")
def test_renderable(self):
"""
If a L{FlattenerError} is created with an L{IRenderable} provider root,
the repr of that object is included in the string representation of the
exception.
"""
class Renderable(object):
implements(IRenderable)
def __repr__(self):
return "renderable repr"
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), [Renderable()], [])),
"Exception while flattening:\n"
" renderable repr\n"
"RuntimeError: reason\n")
def test_tag(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance with source
location information, the source location is included in the string
representation of the exception.
"""
tag = Tag(
'div', filename='/foo/filename.xhtml', lineNumber=17, columnNumber=12)
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [tag], [])),
"Exception while flattening:\n"
" File \"/foo/filename.xhtml\", line 17, column 12, in \"div\"\n"
"RuntimeError: reason\n")
def test_tagWithoutLocation(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance without source
location information, only the tagName is included in the string
representation of the exception.
"""
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [Tag('span')], [])),
"Exception while flattening:\n"
" Tag <span>\n"
"RuntimeError: reason\n")
def test_traceback(self):
"""
If a L{FlattenerError} is created with traceback frames, they are
included in the string representation of the exception.
"""
# Try to be realistic in creating the data passed in for the traceback
# frames.
def f():
g()
def g():
raise RuntimeError("reason")
try:
f()
except RuntimeError, exc:
# Get the traceback, minus the info for *this* frame
tbinfo = traceback.extract_tb(sys.exc_info()[2])[1:]
else:
self.fail("f() must raise RuntimeError")
self.assertEqual(
str(FlattenerError(exc, [], tbinfo)),
"Exception while flattening:\n"
" File \"%s\", line %d, in f\n"
" g()\n"
" File \"%s\", line %d, in g\n"
" raise RuntimeError(\"reason\")\n"
"RuntimeError: reason\n" % (
HERE, f.func_code.co_firstlineno + 1,
HERE, g.func_code.co_firstlineno + 1))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,631 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.http_headers}.
"""
from __future__ import division, absolute_import
import sys
from twisted.python.compat import _PY3
from twisted.trial.unittest import TestCase
from twisted.web.http_headers import _DictHeaders, Headers
class HeadersTests(TestCase):
"""
Tests for L{Headers}.
"""
def test_initializer(self):
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}.
"""
h = Headers({b'Foo': [b'bar']})
self.assertEqual(h.getRawHeaders(b'foo'), [b'bar'])
def test_setRawHeaders(self):
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of byte string values.
"""
rawValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders(b"test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertEqual(h.getRawHeaders(b"test"), rawValue)
def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, {b'Foo': b'bar'})
def test_addRawHeader(self):
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(b"test", b"lemur")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
h.addRawHeader(b"test", b"panda")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self):
"""
L{Headers.getRawHeaders} returns C{None} if the header is not found and
no default is specified.
"""
self.assertIdentical(Headers().getRawHeaders(b"test"), None)
def test_getRawHeadersDefaultValue(self):
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(b"test", default), default)
def test_getRawHeaders(self):
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test"), [b"lemur"])
def test_hasHeaderTrue(self):
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
def test_hasHeaderFalse(self):
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(b"test"))
def test_removeHeader(self):
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(b"foo", [b"lemur"])
self.assertTrue(h.hasHeader(b"foo"))
h.removeHeader(b"foo")
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(b"bar", [b"panda"])
self.assertTrue(h.hasHeader(b"bar"))
h.removeHeader(b"Bar")
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self):
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(b"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_canonicalNameCaps(self):
"""
L{Headers._canonicalNameCaps} returns the canonical capitalization for
the given header.
"""
h = Headers()
self.assertEqual(h._canonicalNameCaps(b"test"), b"Test")
self.assertEqual(h._canonicalNameCaps(b"test-stuff"), b"Test-Stuff")
self.assertEqual(h._canonicalNameCaps(b"content-md5"), b"Content-MD5")
self.assertEqual(h._canonicalNameCaps(b"dnt"), b"DNT")
self.assertEqual(h._canonicalNameCaps(b"etag"), b"ETag")
self.assertEqual(h._canonicalNameCaps(b"p3p"), b"P3P")
self.assertEqual(h._canonicalNameCaps(b"te"), b"TE")
self.assertEqual(h._canonicalNameCaps(b"www-authenticate"),
b"WWW-Authenticate")
self.assertEqual(h._canonicalNameCaps(b"x-xss-protection"),
b"X-XSS-Protection")
def test_getAllRawHeaders(self):
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemurs"])
h.setRawHeaders(b"www-authenticate", [b"basic aksljdlk="])
allHeaders = set([(k, tuple(v)) for k, v in h.getAllRawHeaders()])
self.assertEqual(allHeaders,
set([(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Test", (b"lemurs",))]))
def test_headersComparison(self):
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(b"foo", [b"panda"])
second = Headers()
second.setRawHeaders(b"foo", [b"panda"])
third = Headers()
third.setRawHeaders(b"foo", [b"lemur", b"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
def test_otherComparison(self):
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, b"foo")
def test_repr(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%r: [%r, %r]})" % (foo, bar, baz))
def test_subclassRepr(self):
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%r: [%r, %r]})" % (foo, bar, baz))
def test_copy(self):
"""
L{Headers.copy} creates a new independant copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(b'test', [b'foo'])
i = h.copy()
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
h.addRawHeader(b'test', b'bar')
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
i.addRawHeader(b'test', b'baz')
self.assertEqual(h.getRawHeaders(b'test'), [b'foo', b'bar'])
class HeaderDictTests(TestCase):
"""
Tests for the backwards compatible C{dict} interface for L{Headers}
provided by L{_DictHeaders}.
"""
def headers(self, **kw):
"""
Create a L{Headers} instance populated with the header name/values
specified by C{kw} and a L{_DictHeaders} wrapped around it and return
them both.
"""
h = Headers()
for k, v in kw.items():
h.setRawHeaders(k.encode('ascii'), v)
return h, _DictHeaders(h)
def test_getItem(self):
"""
L{_DictHeaders.__getitem__} returns a single header for the given name.
"""
headers, wrapper = self.headers(test=[b"lemur"])
self.assertEqual(wrapper[b"test"], b"lemur")
def test_getItemMultiple(self):
"""
L{_DictHeaders.__getitem__} returns only the last header value for a
given name.
"""
headers, wrapper = self.headers(test=[b"lemur", b"panda"])
self.assertEqual(wrapper[b"test"], b"panda")
def test_getItemMissing(self):
"""
L{_DictHeaders.__getitem__} raises L{KeyError} if called with a header
which is not present.
"""
headers, wrapper = self.headers()
exc = self.assertRaises(KeyError, wrapper.__getitem__, b"test")
self.assertEqual(exc.args, (b"test",))
def test_iteration(self):
"""
L{_DictHeaders.__iter__} returns an iterator the elements of which
are the lowercase name of each header present.
"""
headers, wrapper = self.headers(foo=[b"lemur", b"panda"], bar=[b"baz"])
self.assertEqual(set(list(wrapper)), set([b"foo", b"bar"]))
def test_length(self):
"""
L{_DictHeaders.__len__} returns the number of headers present.
"""
headers, wrapper = self.headers()
self.assertEqual(len(wrapper), 0)
headers.setRawHeaders(b"foo", [b"bar"])
self.assertEqual(len(wrapper), 1)
headers.setRawHeaders(b"test", [b"lemur", b"panda"])
self.assertEqual(len(wrapper), 2)
def test_setItem(self):
"""
L{_DictHeaders.__setitem__} sets a single header value for the given
name.
"""
headers, wrapper = self.headers()
wrapper[b"test"] = b"lemur"
self.assertEqual(headers.getRawHeaders(b"test"), [b"lemur"])
def test_setItemOverwrites(self):
"""
L{_DictHeaders.__setitem__} will replace any previous header values for
the given name.
"""
headers, wrapper = self.headers(test=[b"lemur", b"panda"])
wrapper[b"test"] = b"lemur"
self.assertEqual(headers.getRawHeaders(b"test"), [b"lemur"])
def test_delItem(self):
"""
L{_DictHeaders.__delitem__} will remove the header values for the given
name.
"""
headers, wrapper = self.headers(test=[b"lemur"])
del wrapper[b"test"]
self.assertFalse(headers.hasHeader(b"test"))
def test_delItemMissing(self):
"""
L{_DictHeaders.__delitem__} will raise L{KeyError} if the given name is
not present.
"""
headers, wrapper = self.headers()
exc = self.assertRaises(KeyError, wrapper.__delitem__, b"test")
self.assertEqual(exc.args, (b"test",))
def test_keys(self, _method='keys', _requireList=not _PY3):
"""
L{_DictHeaders.keys} will return a list of all present header names.
"""
headers, wrapper = self.headers(test=[b"lemur"], foo=[b"bar"])
keys = getattr(wrapper, _method)()
if _requireList:
self.assertIsInstance(keys, list)
self.assertEqual(set(keys), set([b"foo", b"test"]))
def test_iterkeys(self):
"""
L{_DictHeaders.iterkeys} will return all present header names.
"""
self.test_keys('iterkeys', False)
def test_values(self, _method='values', _requireList=not _PY3):
"""
L{_DictHeaders.values} will return a list of all present header values,
returning only the last value for headers with more than one.
"""
headers, wrapper = self.headers(
foo=[b"lemur"], bar=[b"marmot", b"panda"])
values = getattr(wrapper, _method)()
if _requireList:
self.assertIsInstance(values, list)
self.assertEqual(set(values), set([b"lemur", b"panda"]))
def test_itervalues(self):
"""
L{_DictHeaders.itervalues} will return all present header values,
returning only the last value for headers with more than one.
"""
self.test_values('itervalues', False)
def test_items(self, _method='items', _requireList=not _PY3):
"""
L{_DictHeaders.items} will return a list of all present header names
and values as tuples, returning only the last value for headers with
more than one.
"""
headers, wrapper = self.headers(
foo=[b"lemur"], bar=[b"marmot", b"panda"])
items = getattr(wrapper, _method)()
if _requireList:
self.assertIsInstance(items, list)
self.assertEqual(
set(items), set([(b"foo", b"lemur"), (b"bar", b"panda")]))
def test_iteritems(self):
"""
L{_DictHeaders.iteritems} will return all present header names and
values as tuples, returning only the last value for headers with more
than one.
"""
self.test_items('iteritems', False)
def test_clear(self):
"""
L{_DictHeaders.clear} will remove all headers.
"""
headers, wrapper = self.headers(foo=[b"lemur"], bar=[b"panda"])
wrapper.clear()
self.assertEqual(list(headers.getAllRawHeaders()), [])
def test_copy(self):
"""
L{_DictHeaders.copy} will return a C{dict} with all the same headers
and the last value for each.
"""
headers, wrapper = self.headers(
foo=[b"lemur", b"panda"], bar=[b"marmot"])
duplicate = wrapper.copy()
self.assertEqual(duplicate, {b"foo": b"panda", b"bar": b"marmot"})
def test_get(self):
"""
L{_DictHeaders.get} returns the last value for the given header name.
"""
headers, wrapper = self.headers(foo=[b"lemur", b"panda"])
self.assertEqual(wrapper.get(b"foo"), b"panda")
def test_getMissing(self):
"""
L{_DictHeaders.get} returns C{None} for a header which is not present.
"""
headers, wrapper = self.headers()
self.assertIdentical(wrapper.get(b"foo"), None)
def test_getDefault(self):
"""
L{_DictHeaders.get} returns the last value for the given header name
even when it is invoked with a default value.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
self.assertEqual(wrapper.get(b"foo", b"bar"), b"lemur")
def test_getDefaultMissing(self):
"""
L{_DictHeaders.get} returns the default value specified if asked for a
header which is not present.
"""
headers, wrapper = self.headers()
self.assertEqual(wrapper.get(b"foo", b"bar"), b"bar")
def test_has_key(self):
"""
L{_DictHeaders.has_key} returns C{True} if the given header is present,
C{False} otherwise.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
self.assertTrue(wrapper.has_key(b"foo"))
self.assertFalse(wrapper.has_key(b"bar"))
def test_contains(self):
"""
L{_DictHeaders.__contains__} returns C{True} if the given header is
present, C{False} otherwise.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
self.assertIn(b"foo", wrapper)
self.assertNotIn(b"bar", wrapper)
def test_pop(self):
"""
L{_DictHeaders.pop} returns the last header value associated with the
given header name and removes the header.
"""
headers, wrapper = self.headers(foo=[b"lemur", b"panda"])
self.assertEqual(wrapper.pop(b"foo"), b"panda")
self.assertIdentical(headers.getRawHeaders(b"foo"), None)
def test_popMissing(self):
"""
L{_DictHeaders.pop} raises L{KeyError} if passed a header name which is
not present.
"""
headers, wrapper = self.headers()
self.assertRaises(KeyError, wrapper.pop, b"foo")
def test_popDefault(self):
"""
L{_DictHeaders.pop} returns the last header value associated with the
given header name and removes the header, even if it is supplied with a
default value.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
self.assertEqual(wrapper.pop(b"foo", b"bar"), b"lemur")
self.assertIdentical(headers.getRawHeaders(b"foo"), None)
def test_popDefaultMissing(self):
"""
L{_DictHeaders.pop} returns the default value is asked for a header
name which is not present.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
self.assertEqual(wrapper.pop(b"bar", b"baz"), b"baz")
self.assertEqual(headers.getRawHeaders(b"foo"), [b"lemur"])
def test_popitem(self):
"""
L{_DictHeaders.popitem} returns some header name/value pair.
"""
headers, wrapper = self.headers(foo=[b"lemur", b"panda"])
self.assertEqual(wrapper.popitem(), (b"foo", b"panda"))
self.assertIdentical(headers.getRawHeaders(b"foo"), None)
def test_popitemEmpty(self):
"""
L{_DictHeaders.popitem} raises L{KeyError} if there are no headers
present.
"""
headers, wrapper = self.headers()
self.assertRaises(KeyError, wrapper.popitem)
def test_update(self):
"""
L{_DictHeaders.update} adds the header/value pairs in the C{dict} it is
passed, overriding any existing values for those headers.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
wrapper.update({b"foo": b"panda", b"bar": b"marmot"})
self.assertEqual(headers.getRawHeaders(b"foo"), [b"panda"])
self.assertEqual(headers.getRawHeaders(b"bar"), [b"marmot"])
def test_updateWithKeywords(self):
"""
L{_DictHeaders.update} adds header names given as keyword arguments
with the keyword values as the header value.
"""
headers, wrapper = self.headers(foo=[b"lemur"])
wrapper.update(foo=b"panda", bar=b"marmot")
self.assertEqual(headers.getRawHeaders(b"foo"), [b"panda"])
self.assertEqual(headers.getRawHeaders(b"bar"), [b"marmot"])
if _PY3:
test_updateWithKeywords.skip = "Not yet supported on Python 3; see #6082."
def test_setdefaultMissing(self):
"""
If passed the name of a header which is not present,
L{_DictHeaders.setdefault} sets the value of the given header to the
specified default value and returns it.
"""
headers, wrapper = self.headers(foo=[b"bar"])
self.assertEqual(wrapper.setdefault(b"baz", b"quux"), b"quux")
self.assertEqual(headers.getRawHeaders(b"foo"), [b"bar"])
self.assertEqual(headers.getRawHeaders(b"baz"), [b"quux"])
def test_setdefaultPresent(self):
"""
If passed the name of a header which is present,
L{_DictHeaders.setdefault} makes no changes to the headers and
returns the last value already associated with that header.
"""
headers, wrapper = self.headers(foo=[b"bar", b"baz"])
self.assertEqual(wrapper.setdefault(b"foo", b"quux"), b"baz")
self.assertEqual(headers.getRawHeaders(b"foo"), [b"bar", b"baz"])
def test_setdefaultDefault(self):
"""
If a value is not passed to L{_DictHeaders.setdefault}, C{None} is
used.
"""
# This results in an invalid state for the headers, but maybe some
# application is doing this an intermediate step towards some other
# state. Anyway, it was broken with the old implementation so it's
# broken with the new implementation. Compatibility, for the win.
# -exarkun
headers, wrapper = self.headers()
self.assertIdentical(wrapper.setdefault(b"foo"), None)
self.assertEqual(headers.getRawHeaders(b"foo"), [None])
def test_dictComparison(self):
"""
An instance of L{_DictHeaders} compares equal to a C{dict} which
contains the same header/value pairs. For header names with multiple
values, the last value only is considered.
"""
headers, wrapper = self.headers(foo=[b"lemur"], bar=[b"panda", b"marmot"])
self.assertNotEqual(wrapper, {b"foo": b"lemur", b"bar": b"panda"})
self.assertEqual(wrapper, {b"foo": b"lemur", b"bar": b"marmot"})
def test_otherComparison(self):
"""
An instance of L{_DictHeaders} does not compare equal to other
unrelated objects.
"""
headers, wrapper = self.headers()
self.assertNotEqual(wrapper, ())
self.assertNotEqual(wrapper, object())
self.assertNotEqual(wrapper, b"foo")
if _PY3:
# Python 3 lacks these APIs
del test_iterkeys, test_itervalues, test_iteritems, test_has_key

View file

@ -0,0 +1,634 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._auth}.
"""
from zope.interface import implements
from zope.interface.verify import verifyObject
from twisted.trial import unittest
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone
from twisted.internet.address import IPv4Address
from twisted.cred import error, portal
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.checkers import ANONYMOUS, AllowAnonymousAccess
from twisted.cred.credentials import IUsernamePassword
from twisted.web.iweb import ICredentialFactory
from twisted.web.resource import IResource, Resource, getChildForRequest
from twisted.web._auth import basic, digest
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper, UnauthorizedResource
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import Data
from twisted.web.test.test_web import DummyRequest
def b64encode(s):
return s.encode('base64').strip()
class BasicAuthTestsMixin:
"""
L{TestCase} mixin class which defines a number of tests for
L{basic.BasicCredentialFactory}. Because this mixin defines C{setUp}, it
must be inherited before L{TestCase}.
"""
def setUp(self):
self.request = self.makeRequest()
self.realm = 'foo'
self.username = 'dreid'
self.password = 'S3CuR1Ty'
self.credentialFactory = basic.BasicCredentialFactory(self.realm)
def makeRequest(self, method='GET', clientAddress=None):
"""
Create a request object to be passed to
L{basic.BasicCredentialFactory.decode} along with a response value.
Override this in a subclass.
"""
raise NotImplementedError("%r did not implement makeRequest" % (
self.__class__,))
def test_interface(self):
"""
L{BasicCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_usernamePassword(self):
"""
L{basic.BasicCredentialFactory.decode} turns a base64-encoded response
into a L{UsernamePassword} object with a password which reflects the
one which was encoded in the response.
"""
response = b64encode('%s:%s' % (self.username, self.password))
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(IUsernamePassword.providedBy(creds))
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
def test_incorrectPadding(self):
"""
L{basic.BasicCredentialFactory.decode} decodes a base64-encoded
response with incorrect padding.
"""
response = b64encode('%s:%s' % (self.username, self.password))
response = response.strip('=')
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(verifyObject(IUsernamePassword, creds))
self.assertTrue(creds.checkPassword(self.password))
def test_invalidEncoding(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} if passed
a response which is not base64-encoded.
"""
response = 'x' # one byte cannot be valid base64 text
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode, response, self.makeRequest())
def test_invalidCredentials(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} when
passed a response which is not valid base64-encoded text.
"""
response = b64encode('123abc+/')
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode,
response, self.makeRequest())
class RequestMixin:
def makeRequest(self, method='GET', clientAddress=None):
"""
Create a L{DummyRequest} (change me to create a
L{twisted.web.http.Request} instead).
"""
request = DummyRequest('/')
request.method = method
request.client = clientAddress
return request
class BasicAuthTestCase(RequestMixin, BasicAuthTestsMixin, unittest.TestCase):
"""
Basic authentication tests which use L{twisted.web.http.Request}.
"""
class DigestAuthTestCase(RequestMixin, unittest.TestCase):
"""
Digest authentication tests which use L{twisted.web.http.Request}.
"""
def setUp(self):
"""
Create a DigestCredentialFactory for testing
"""
self.realm = "test realm"
self.algorithm = "md5"
self.credentialFactory = digest.DigestCredentialFactory(
self.algorithm, self.realm)
self.request = self.makeRequest()
def test_decode(self):
"""
L{digest.DigestCredentialFactory.decode} calls the C{decode} method on
L{twisted.cred.digest.DigestCredentialFactory} with the HTTP method and
host of the request.
"""
host = '169.254.0.1'
method = 'GET'
done = [False]
response = object()
def check(_response, _method, _host):
self.assertEqual(response, _response)
self.assertEqual(method, _method)
self.assertEqual(host, _host)
done[0] = True
self.patch(self.credentialFactory.digest, 'decode', check)
req = self.makeRequest(method, IPv4Address('TCP', host, 81))
self.credentialFactory.decode(response, req)
self.assertTrue(done[0])
def test_interface(self):
"""
L{DigestCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_getChallenge(self):
"""
The challenge issued by L{DigestCredentialFactory.getChallenge} must
include C{'qop'}, C{'realm'}, C{'algorithm'}, C{'nonce'}, and
C{'opaque'} keys. The values for the C{'realm'} and C{'algorithm'}
keys must match the values supplied to the factory's initializer.
None of the values may have newlines in them.
"""
challenge = self.credentialFactory.getChallenge(self.request)
self.assertEqual(challenge['qop'], 'auth')
self.assertEqual(challenge['realm'], 'test realm')
self.assertEqual(challenge['algorithm'], 'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
for v in challenge.values():
self.assertNotIn('\n', v)
def test_getChallengeWithoutClientIP(self):
"""
L{DigestCredentialFactory.getChallenge} can issue a challenge even if
the L{Request} it is passed returns C{None} from C{getClientIP}.
"""
request = self.makeRequest('GET', None)
challenge = self.credentialFactory.getChallenge(request)
self.assertEqual(challenge['qop'], 'auth')
self.assertEqual(challenge['realm'], 'test realm')
self.assertEqual(challenge['algorithm'], 'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
class UnauthorizedResourceTests(unittest.TestCase):
"""
Tests for L{UnauthorizedResource}.
"""
def test_getChildWithDefault(self):
"""
An L{UnauthorizedResource} is every child of itself.
"""
resource = UnauthorizedResource([])
self.assertIdentical(
resource.getChildWithDefault("foo", None), resource)
self.assertIdentical(
resource.getChildWithDefault("bar", None), resource)
def _unauthorizedRenderTest(self, request):
"""
Render L{UnauthorizedResource} for the given request object and verify
that the response code is I{Unauthorized} and that a I{WWW-Authenticate}
header is set in the response containing a challenge.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example.com')])
request.render(resource)
self.assertEqual(request.responseCode, 401)
self.assertEqual(
request.responseHeaders.getRawHeaders('www-authenticate'),
['basic realm="example.com"'])
def test_render(self):
"""
L{UnauthorizedResource} renders with a 401 response code and a
I{WWW-Authenticate} header and puts a simple unauthorized message
into the response body.
"""
request = DummyRequest([''])
self._unauthorizedRenderTest(request)
self.assertEqual('Unauthorized', ''.join(request.written))
def test_renderHEAD(self):
"""
The rendering behavior of L{UnauthorizedResource} for a I{HEAD} request
is like its handling of a I{GET} request, but no response body is
written.
"""
request = DummyRequest([''])
request.method = 'HEAD'
self._unauthorizedRenderTest(request)
self.assertEqual('', ''.join(request.written))
def test_renderQuotesRealm(self):
"""
The realm value included in the I{WWW-Authenticate} header set in
the response when L{UnauthorizedResounrce} is rendered has quotes
and backslashes escaped.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example\\"foo')])
request = DummyRequest([''])
request.render(resource)
self.assertEqual(
request.responseHeaders.getRawHeaders('www-authenticate'),
['basic realm="example\\\\\\"foo"'])
class Realm(object):
"""
A simple L{IRealm} implementation which gives out L{WebAvatar} for any
avatarId.
@type loggedIn: C{int}
@ivar loggedIn: The number of times C{requestAvatar} has been invoked for
L{IResource}.
@type loggedOut: C{int}
@ivar loggedOut: The number of times the logout callback has been invoked.
"""
implements(portal.IRealm)
def __init__(self, avatarFactory):
self.loggedOut = 0
self.loggedIn = 0
self.avatarFactory = avatarFactory
def requestAvatar(self, avatarId, mind, *interfaces):
if IResource in interfaces:
self.loggedIn += 1
return IResource, self.avatarFactory(avatarId), self.logout
raise NotImplementedError()
def logout(self):
self.loggedOut += 1
class HTTPAuthHeaderTests(unittest.TestCase):
"""
Tests for L{HTTPAuthSessionWrapper}.
"""
makeRequest = DummyRequest
def setUp(self):
"""
Create a realm, portal, and L{HTTPAuthSessionWrapper} to use in the tests.
"""
self.username = 'foo bar'
self.password = 'bar baz'
self.avatarContent = "contents of the avatar resource itself"
self.childName = "foo-child"
self.childContent = "contents of the foo child of the avatar"
self.checker = InMemoryUsernamePasswordDatabaseDontUse()
self.checker.addUser(self.username, self.password)
self.avatar = Data(self.avatarContent, 'text/plain')
self.avatar.putChild(
self.childName, Data(self.childContent, 'text/plain'))
self.avatars = {self.username: self.avatar}
self.realm = Realm(self.avatars.get)
self.portal = portal.Portal(self.realm, [self.checker])
self.credentialFactories = []
self.wrapper = HTTPAuthSessionWrapper(
self.portal, self.credentialFactories)
def _authorizedBasicLogin(self, request):
"""
Add an I{basic authorization} header to the given request and then
dispatch it, starting from C{self.wrapper} and returning the resulting
L{IResource}.
"""
authorization = b64encode(self.username + ':' + self.password)
request.headers['authorization'] = 'Basic ' + authorization
return getChildForRequest(self.wrapper, request)
def test_getChildWithDefault(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} instance when the request does
not have the required I{Authorization} headers.
"""
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def _invalidAuthorizationTest(self, response):
"""
Create a request with the given value as the value of an
I{Authorization} header and perform resource traversal with it,
starting at C{self.wrapper}. Assert that the result is a 401 response
code. Return a L{Deferred} which fires when this is all done.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.headers['authorization'] = response
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChildWithDefaultUnauthorizedUser(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which does not exist.
"""
return self._invalidAuthorizationTest('Basic ' + b64encode('foo:bar'))
def test_getChildWithDefaultUnauthorizedPassword(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which exists and the wrong
password.
"""
return self._invalidAuthorizationTest(
'Basic ' + b64encode(self.username + ':bar'))
def test_getChildWithDefaultUnrecognizedScheme(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with an unrecognized scheme.
"""
return self._invalidAuthorizationTest('Quux foo bar baz')
def test_getChildWithDefaultAuthorized(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{IResource} which renders the L{IResource} avatar
retrieved from the portal when the request has a valid I{Authorization}
header.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.childContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_renderAuthorized(self):
"""
Resource traversal which terminates at an L{HTTPAuthSessionWrapper}
and includes correct authentication headers results in the
L{IResource} avatar (not one of its children) retrieved from the
portal being rendered.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
# Request it exactly, not any of its children.
request = self.makeRequest([])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.avatarContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChallengeCalledWithRequest(self):
"""
When L{HTTPAuthSessionWrapper} finds an L{ICredentialFactory} to issue
a challenge, it calls the C{getChallenge} method with the request as an
argument.
"""
class DumbCredentialFactory(object):
implements(ICredentialFactory)
scheme = 'dumb'
def __init__(self):
self.requests = []
def getChallenge(self, request):
self.requests.append(request)
return {}
factory = DumbCredentialFactory()
self.credentialFactories.append(factory)
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(factory.requests, [request])
d.addCallback(cbFinished)
request.render(child)
return d
def _logoutTest(self):
"""
Issue a request for an authentication-protected resource using valid
credentials and then return the C{DummyRequest} instance which was
used.
This is a helper for tests about the behavior of the logout
callback.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
class SlowerResource(Resource):
def render(self, request):
return NOT_DONE_YET
self.avatar.putChild(self.childName, SlowerResource())
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(self.realm.loggedOut, 0)
return request
def test_logout(self):
"""
The realm's logout callback is invoked after the resource is rendered.
"""
request = self._logoutTest()
request.finish()
self.assertEqual(self.realm.loggedOut, 1)
def test_logoutOnError(self):
"""
The realm's logout callback is also invoked if there is an error
generating the response (for example, if the client disconnects
early).
"""
request = self._logoutTest()
request.processingFailed(
Failure(ConnectionDone("Simulated disconnect")))
self.assertEqual(self.realm.loggedOut, 1)
def test_decodeRaises(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has a I{Basic
Authorization} header which cannot be decoded using base64.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.headers['authorization'] = 'Basic decode should fail'
child = getChildForRequest(self.wrapper, request)
self.assertIsInstance(child, UnauthorizedResource)
def test_selectParseResponse(self):
"""
L{HTTPAuthSessionWrapper._selectParseHeader} returns a two-tuple giving
the L{ICredentialFactory} to use to parse the header and a string
containing the portion of the header which remains to be parsed.
"""
basicAuthorization = 'Basic abcdef123456'
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(None, None))
factory = BasicCredentialFactory('example.com')
self.credentialFactories.append(factory)
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(factory, 'abcdef123456'))
def test_unexpectedDecodeError(self):
"""
Any unexpected exception raised by the credential factory's C{decode}
method results in a 500 response code and causes the exception to be
logged.
"""
class UnexpectedException(Exception):
pass
class BadFactory(object):
scheme = 'bad'
def getChallenge(self, client):
return {}
def decode(self, response, request):
raise UnexpectedException()
self.credentialFactories.append(BadFactory())
request = self.makeRequest([self.childName])
request.headers['authorization'] = 'Bad abc'
child = getChildForRequest(self.wrapper, request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_unexpectedLoginError(self):
"""
Any unexpected failure from L{Portal.login} results in a 500 response
code and causes the failure to be logged.
"""
class UnexpectedException(Exception):
pass
class BrokenChecker(object):
credentialInterfaces = (IUsernamePassword,)
def requestAvatarId(self, credentials):
raise UnexpectedException()
self.portal.registerChecker(BrokenChecker())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_anonymousAccess(self):
"""
Anonymous requests are allowed if a L{Portal} has an anonymous checker
registered.
"""
unprotectedContents = "contents of the unprotected child resource"
self.avatars[ANONYMOUS] = Resource()
self.avatars[ANONYMOUS].putChild(
self.childName, Data(unprotectedContents, 'text/plain'))
self.portal.registerChecker(AllowAnonymousAccess())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [unprotectedContents])
d.addCallback(cbFinished)
request.render(child)
return d

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,544 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test for L{twisted.web.proxy}.
"""
from twisted.trial.unittest import TestCase
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.proxy import ReverseProxyResource, ProxyClientFactory
from twisted.web.proxy import ProxyClient, ProxyRequest, ReverseProxyRequest
from twisted.web.test.test_web import DummyRequest
class ReverseProxyResourceTestCase(TestCase):
"""
Tests for L{ReverseProxyResource}.
"""
def _testRender(self, uri, expectedURI):
"""
Check that a request pointing at C{uri} produce a new proxy connection,
with the path of this request pointing at C{expectedURI}.
"""
root = Resource()
reactor = MemoryReactor()
resource = ReverseProxyResource("127.0.0.1", 1234, "/path", reactor)
root.putChild('index', resource)
site = Site(root)
transport = StringTransportWithDisconnection()
channel = site.buildProtocol(None)
channel.makeConnection(transport)
# Clear the timeout if the tests failed
self.addCleanup(channel.connectionLost, None)
channel.dataReceived("GET %s HTTP/1.1\r\nAccept: text/html\r\n\r\n" %
(uri,))
# Check that one connection has been created, to the good host/port
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "127.0.0.1")
self.assertEqual(reactor.tcpClients[0][1], 1234)
# Check the factory passed to the connect, and its given path
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.headers["host"], "127.0.0.1:1234")
def test_render(self):
"""
Test that L{ReverseProxyResource.render} initiates a connection to the
given server with a L{ProxyClientFactory} as parameter.
"""
return self._testRender("/index", "/path")
def test_renderWithQuery(self):
"""
Test that L{ReverseProxyResource.render} passes query parameters to the
created factory.
"""
return self._testRender("/index?foo=bar", "/path?foo=bar")
def test_getChild(self):
"""
The L{ReverseProxyResource.getChild} method should return a resource
instance with the same class as the originating resource, forward
port, host, and reactor values, and update the path value with the
value passed.
"""
reactor = MemoryReactor()
resource = ReverseProxyResource("127.0.0.1", 1234, "/path", reactor)
child = resource.getChild('foo', None)
# The child should keep the same class
self.assertIsInstance(child, ReverseProxyResource)
self.assertEqual(child.path, "/path/foo")
self.assertEqual(child.port, 1234)
self.assertEqual(child.host, "127.0.0.1")
self.assertIdentical(child.reactor, resource.reactor)
def test_getChildWithSpecial(self):
"""
The L{ReverseProxyResource} return by C{getChild} has a path which has
already been quoted.
"""
resource = ReverseProxyResource("127.0.0.1", 1234, "/path")
child = resource.getChild(' /%', None)
self.assertEqual(child.path, "/path/%20%2F%25")
class DummyChannel(object):
"""
A dummy HTTP channel, that does nothing but holds a transport and saves
connection lost.
@ivar transport: the transport used by the client.
@ivar lostReason: the reason saved at connection lost.
"""
def __init__(self, transport):
"""
Hold a reference to the transport.
"""
self.transport = transport
self.lostReason = None
def connectionLost(self, reason):
"""
Keep track of the connection lost reason.
"""
self.lostReason = reason
class ProxyClientTestCase(TestCase):
"""
Tests for L{ProxyClient}.
"""
def _parseOutHeaders(self, content):
"""
Parse the headers out of some web content.
@param content: Bytes received from a web server.
@return: A tuple of (requestLine, headers, body). C{headers} is a dict
of headers, C{requestLine} is the first line (e.g. "POST /foo ...")
and C{body} is whatever is left.
"""
headers, body = content.split('\r\n\r\n')
headers = headers.split('\r\n')
requestLine = headers.pop(0)
return (
requestLine, dict(header.split(': ') for header in headers), body)
def makeRequest(self, path):
"""
Make a dummy request object for the URL path.
@param path: A URL path, beginning with a slash.
@return: A L{DummyRequest}.
"""
return DummyRequest(path)
def makeProxyClient(self, request, method="GET", headers=None,
requestBody=""):
"""
Make a L{ProxyClient} object used for testing.
@param request: The request to use.
@param method: The HTTP method to use, GET by default.
@param headers: The HTTP headers to use expressed as a dict. If not
provided, defaults to {'accept': 'text/html'}.
@param requestBody: The body of the request. Defaults to the empty
string.
@return: A L{ProxyClient}
"""
if headers is None:
headers = {"accept": "text/html"}
path = '/' + request.postpath
return ProxyClient(
method, path, 'HTTP/1.0', headers, requestBody, request)
def connectProxy(self, proxyClient):
"""
Connect a proxy client to a L{StringTransportWithDisconnection}.
@param proxyClient: A L{ProxyClient}.
@return: The L{StringTransportWithDisconnection}.
"""
clientTransport = StringTransportWithDisconnection()
clientTransport.protocol = proxyClient
proxyClient.makeConnection(clientTransport)
return clientTransport
def assertForwardsHeaders(self, proxyClient, requestLine, headers):
"""
Assert that C{proxyClient} sends C{headers} when it connects.
@param proxyClient: A L{ProxyClient}.
@param requestLine: The request line we expect to be sent.
@param headers: A dict of headers we expect to be sent.
@return: If the assertion is successful, return the request body as
bytes.
"""
self.connectProxy(proxyClient)
requestContent = proxyClient.transport.value()
receivedLine, receivedHeaders, body = self._parseOutHeaders(
requestContent)
self.assertEqual(receivedLine, requestLine)
self.assertEqual(receivedHeaders, headers)
return body
def makeResponseBytes(self, code, message, headers, body):
lines = ["HTTP/1.0 %d %s" % (code, message)]
for header, values in headers:
for value in values:
lines.append("%s: %s" % (header, value))
lines.extend(['', body])
return '\r\n'.join(lines)
def assertForwardsResponse(self, request, code, message, headers, body):
"""
Assert that C{request} has forwarded a response from the server.
@param request: A L{DummyRequest}.
@param code: The expected HTTP response code.
@param message: The expected HTTP message.
@param headers: The expected HTTP headers.
@param body: The expected response body.
"""
self.assertEqual(request.responseCode, code)
self.assertEqual(request.responseMessage, message)
receivedHeaders = list(request.responseHeaders.getAllRawHeaders())
receivedHeaders.sort()
expectedHeaders = headers[:]
expectedHeaders.sort()
self.assertEqual(receivedHeaders, expectedHeaders)
self.assertEqual(''.join(request.written), body)
def _testDataForward(self, code, message, headers, body, method="GET",
requestBody="", loseConnection=True):
"""
Build a fake proxy connection, and send C{data} over it, checking that
it's forwarded to the originating request.
"""
request = self.makeRequest('foo')
client = self.makeProxyClient(
request, method, {'accept': 'text/html'}, requestBody)
receivedBody = self.assertForwardsHeaders(
client, '%s /foo HTTP/1.0' % (method,),
{'connection': 'close', 'accept': 'text/html'})
self.assertEqual(receivedBody, requestBody)
# Fake an answer
client.dataReceived(
self.makeResponseBytes(code, message, headers, body))
# Check that the response data has been forwarded back to the original
# requester.
self.assertForwardsResponse(request, code, message, headers, body)
# Check that when the response is done, the request is finished.
if loseConnection:
client.transport.loseConnection()
# Even if we didn't call loseConnection, the transport should be
# disconnected. This lets us not rely on the server to close our
# sockets for us.
self.assertFalse(client.transport.connected)
self.assertEqual(request.finished, 1)
def test_forward(self):
"""
When connected to the server, L{ProxyClient} should send the saved
request, with modifications of the headers, and then forward the result
to the parent request.
"""
return self._testDataForward(
200, "OK", [("Foo", ["bar", "baz"])], "Some data\r\n")
def test_postData(self):
"""
Try to post content in the request, and check that the proxy client
forward the body of the request.
"""
return self._testDataForward(
200, "OK", [("Foo", ["bar"])], "Some data\r\n", "POST", "Some content")
def test_statusWithMessage(self):
"""
If the response contains a status with a message, it should be
forwarded to the parent request with all the information.
"""
return self._testDataForward(
404, "Not Found", [], "")
def test_contentLength(self):
"""
If the response contains a I{Content-Length} header, the inbound
request object should still only have C{finish} called on it once.
"""
data = "foo bar baz"
return self._testDataForward(
200, "OK", [("Content-Length", [str(len(data))])], data)
def test_losesConnection(self):
"""
If the response contains a I{Content-Length} header, the outgoing
connection is closed when all response body data has been received.
"""
data = "foo bar baz"
return self._testDataForward(
200, "OK", [("Content-Length", [str(len(data))])], data,
loseConnection=False)
def test_headersCleanups(self):
"""
The headers given at initialization should be modified:
B{proxy-connection} should be removed if present, and B{connection}
should be added.
"""
client = ProxyClient('GET', '/foo', 'HTTP/1.0',
{"accept": "text/html", "proxy-connection": "foo"}, '', None)
self.assertEqual(client.headers,
{"accept": "text/html", "connection": "close"})
def test_keepaliveNotForwarded(self):
"""
The proxy doesn't really know what to do with keepalive things from
the remote server, so we stomp over any keepalive header we get from
the client.
"""
headers = {
"accept": "text/html",
'keep-alive': '300',
'connection': 'keep-alive',
}
expectedHeaders = headers.copy()
expectedHeaders['connection'] = 'close'
del expectedHeaders['keep-alive']
client = ProxyClient('GET', '/foo', 'HTTP/1.0', headers, '', None)
self.assertForwardsHeaders(
client, 'GET /foo HTTP/1.0', expectedHeaders)
def test_defaultHeadersOverridden(self):
"""
L{server.Request} within the proxy sets certain response headers by
default. When we get these headers back from the remote server, the
defaults are overridden rather than simply appended.
"""
request = self.makeRequest('foo')
request.responseHeaders.setRawHeaders('server', ['old-bar'])
request.responseHeaders.setRawHeaders('date', ['old-baz'])
request.responseHeaders.setRawHeaders('content-type', ["old/qux"])
client = self.makeProxyClient(request, headers={'accept': 'text/html'})
self.connectProxy(client)
headers = {
'Server': ['bar'],
'Date': ['2010-01-01'],
'Content-Type': ['application/x-baz'],
}
client.dataReceived(
self.makeResponseBytes(200, "OK", headers.items(), ''))
self.assertForwardsResponse(
request, 200, 'OK', headers.items(), '')
class ProxyClientFactoryTestCase(TestCase):
"""
Tests for L{ProxyClientFactory}.
"""
def test_connectionFailed(self):
"""
Check that L{ProxyClientFactory.clientConnectionFailed} produces
a B{501} response to the parent request.
"""
request = DummyRequest(['foo'])
factory = ProxyClientFactory('GET', '/foo', 'HTTP/1.0',
{"accept": "text/html"}, '', request)
factory.clientConnectionFailed(None, None)
self.assertEqual(request.responseCode, 501)
self.assertEqual(request.responseMessage, "Gateway error")
self.assertEqual(
list(request.responseHeaders.getAllRawHeaders()),
[("Content-Type", ["text/html"])])
self.assertEqual(
''.join(request.written),
"<H1>Could not connect</H1>")
self.assertEqual(request.finished, 1)
def test_buildProtocol(self):
"""
L{ProxyClientFactory.buildProtocol} should produce a L{ProxyClient}
with the same values of attributes (with updates on the headers).
"""
factory = ProxyClientFactory('GET', '/foo', 'HTTP/1.0',
{"accept": "text/html"}, 'Some data',
None)
proto = factory.buildProtocol(None)
self.assertIsInstance(proto, ProxyClient)
self.assertEqual(proto.command, 'GET')
self.assertEqual(proto.rest, '/foo')
self.assertEqual(proto.data, 'Some data')
self.assertEqual(proto.headers,
{"accept": "text/html", "connection": "close"})
class ProxyRequestTestCase(TestCase):
"""
Tests for L{ProxyRequest}.
"""
def _testProcess(self, uri, expectedURI, method="GET", data=""):
"""
Build a request pointing at C{uri}, and check that a proxied request
is created, pointing a C{expectedURI}.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(len(data))
request.handleContentChunk(data)
request.requestReceived(method, 'http://example.com%s' % (uri,),
'HTTP/1.0')
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 80)
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.command, method)
self.assertEqual(factory.version, 'HTTP/1.0')
self.assertEqual(factory.headers, {'host': 'example.com'})
self.assertEqual(factory.data, data)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.father, request)
def test_process(self):
"""
L{ProxyRequest.process} should create a connection to the given server,
with a L{ProxyClientFactory} as connection factory, with the correct
parameters:
- forward comment, version and data values
- update headers with the B{host} value
- remove the host from the URL
- pass the request as parent request
"""
return self._testProcess("/foo/bar", "/foo/bar")
def test_processWithoutTrailingSlash(self):
"""
If the incoming request doesn't contain a slash,
L{ProxyRequest.process} should add one when instantiating
L{ProxyClientFactory}.
"""
return self._testProcess("", "/")
def test_processWithData(self):
"""
L{ProxyRequest.process} should be able to retrieve request body and
to forward it.
"""
return self._testProcess(
"/foo/bar", "/foo/bar", "POST", "Some content")
def test_processWithPort(self):
"""
Check that L{ProxyRequest.process} correctly parse port in the incoming
URL, and create a outgoing connection with this port.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(0)
request.requestReceived('GET', 'http://example.com:1234/foo/bar',
'HTTP/1.0')
# That should create one connection, with the port parsed from the URL
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
class DummyFactory(object):
"""
A simple holder for C{host} and C{port} information.
"""
def __init__(self, host, port):
self.host = host
self.port = port
class ReverseProxyRequestTestCase(TestCase):
"""
Tests for L{ReverseProxyRequest}.
"""
def test_process(self):
"""
L{ReverseProxyRequest.process} should create a connection to its
factory host/port, using a L{ProxyClientFactory} instantiated with the
correct parameters, and particulary set the B{host} header to the
factory host.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ReverseProxyRequest(channel, False, reactor)
request.factory = DummyFactory("example.com", 1234)
request.gotLength(0)
request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
# Check that one connection has been created, to the good host/port
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
# Check the factory passed to the connect, and its headers
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.headers, {'host': 'example.com'})

View file

@ -0,0 +1,261 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.resource}.
"""
from twisted.trial.unittest import TestCase
from twisted.web.error import UnsupportedMethod
from twisted.web.resource import (
NOT_FOUND, FORBIDDEN, Resource, ErrorPage, NoResource, ForbiddenResource,
getChildForRequest)
from twisted.web.test.requesthelper import DummyRequest
class ErrorPageTests(TestCase):
"""
Tests for L{ErrorPage}, L{NoResource}, and L{ForbiddenResource}.
"""
errorPage = ErrorPage
noResource = NoResource
forbiddenResource = ForbiddenResource
def test_getChild(self):
"""
The C{getChild} method of L{ErrorPage} returns the L{ErrorPage} it is
called on.
"""
page = self.errorPage(321, "foo", "bar")
self.assertIdentical(page.getChild(b"name", object()), page)
def _pageRenderingTest(self, page, code, brief, detail):
request = DummyRequest([b''])
template = (
u"\n"
u"<html>\n"
u" <head><title>%s - %s</title></head>\n"
u" <body>\n"
u" <h1>%s</h1>\n"
u" <p>%s</p>\n"
u" </body>\n"
u"</html>\n")
expected = template % (code, brief, brief, detail)
self.assertEqual(
page.render(request), expected.encode('utf-8'))
self.assertEqual(request.responseCode, code)
self.assertEqual(
request.outgoingHeaders,
{b'content-type': b'text/html; charset=utf-8'})
def test_errorPageRendering(self):
"""
L{ErrorPage.render} returns a C{bytes} describing the error defined by
the response code and message passed to L{ErrorPage.__init__}. It also
uses that response code to set the response code on the L{Request}
passed in.
"""
code = 321
brief = "brief description text"
detail = "much longer text might go here"
page = self.errorPage(code, brief, detail)
self._pageRenderingTest(page, code, brief, detail)
def test_noResourceRendering(self):
"""
L{NoResource} sets the HTTP I{NOT FOUND} code.
"""
detail = "long message"
page = self.noResource(detail)
self._pageRenderingTest(page, NOT_FOUND, "No Such Resource", detail)
def test_forbiddenResourceRendering(self):
"""
L{ForbiddenResource} sets the HTTP I{FORBIDDEN} code.
"""
detail = "longer message"
page = self.forbiddenResource(detail)
self._pageRenderingTest(page, FORBIDDEN, "Forbidden Resource", detail)
class DynamicChild(Resource):
"""
A L{Resource} to be created on the fly by L{DynamicChildren}.
"""
def __init__(self, path, request):
Resource.__init__(self)
self.path = path
self.request = request
class DynamicChildren(Resource):
"""
A L{Resource} with dynamic children.
"""
def getChild(self, path, request):
return DynamicChild(path, request)
class BytesReturnedRenderable(Resource):
"""
A L{Resource} with minimal capabilities to render a response.
"""
def __init__(self, response):
"""
@param response: A C{bytes} object giving the value to return from
C{render_GET}.
"""
Resource.__init__(self)
self._response = response
def render_GET(self, request):
"""
Render a response to a I{GET} request by returning a short byte string
to be written by the server.
"""
return self._response
class ImplicitAllowedMethods(Resource):
"""
A L{Resource} which implicitly defines its allowed methods by defining
renderers to handle them.
"""
def render_GET(self, request):
pass
def render_PUT(self, request):
pass
class ResourceTests(TestCase):
"""
Tests for L{Resource}.
"""
def test_staticChildren(self):
"""
L{Resource.putChild} adds a I{static} child to the resource. That child
is returned from any call to L{Resource.getChildWithDefault} for the
child's path.
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(b"foo", child)
resource.putChild(b"bar", sibling)
self.assertIdentical(
child, resource.getChildWithDefault(b"foo", DummyRequest([])))
def test_dynamicChildren(self):
"""
L{Resource.getChildWithDefault} delegates to L{Resource.getChild} when
the requested path is not associated with any static child.
"""
path = b"foo"
request = DummyRequest([])
resource = DynamicChildren()
child = resource.getChildWithDefault(path, request)
self.assertIsInstance(child, DynamicChild)
self.assertEqual(child.path, path)
self.assertIdentical(child.request, request)
def test_defaultHEAD(self):
"""
When not otherwise overridden, L{Resource.render} treats a I{HEAD}
request as if it were a I{GET} request.
"""
expected = b"insert response here"
request = DummyRequest([])
request.method = b'HEAD'
resource = BytesReturnedRenderable(expected)
self.assertEqual(expected, resource.render(request))
def test_explicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to the value of the
C{allowedMethods} attribute of the L{Resource}, if it has one.
"""
expected = [b'GET', b'HEAD', b'PUT']
resource = Resource()
resource.allowedMethods = expected
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(set(expected), set(exc.allowedMethods))
def test_implicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to a list of the
methods supported by the L{Resource}, as determined by the
I{render_}-prefixed methods which it defines, if C{allowedMethods} is
not explicitly defined by the L{Resource}.
"""
expected = set([b'GET', b'HEAD', b'PUT'])
resource = ImplicitAllowedMethods()
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(expected, set(exc.allowedMethods))
class GetChildForRequestTests(TestCase):
"""
Tests for L{getChildForRequest}.
"""
def test_exhaustedPostPath(self):
"""
L{getChildForRequest} returns whatever resource has been reached by the
time the request's C{postpath} is empty.
"""
request = DummyRequest([])
resource = Resource()
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_leafResource(self):
"""
L{getChildForRequest} returns the first resource it encounters with a
C{isLeaf} attribute set to C{True}.
"""
request = DummyRequest([b"foo", b"bar"])
resource = Resource()
resource.isLeaf = True
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_postPathToPrePath(self):
"""
As path segments from the request are traversed, they are taken from
C{postpath} and put into C{prepath}.
"""
request = DummyRequest([b"foo", b"bar"])
root = Resource()
child = Resource()
child.isLeaf = True
root.putChild(b"foo", child)
self.assertIdentical(child, getChildForRequest(root, request))
self.assertEqual(request.prepath, [b"foo"])
self.assertEqual(request.postpath, [b"bar"])

View file

@ -0,0 +1,70 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.script}.
"""
import os
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.script import ResourceScriptDirectory, PythonScript
from twisted.web.test._util import _render
from twisted.web.test.test_web import DummyRequest
class ResourceScriptDirectoryTests(TestCase):
"""
Tests for L{ResourceScriptDirectory}.
"""
def test_render(self):
"""
L{ResourceScriptDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = ResourceScriptDirectory(self.mktemp())
request = DummyRequest([''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{ResourceScriptDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = ResourceScriptDirectory(path)
request = DummyRequest(['foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
class PythonScriptTests(TestCase):
"""
Tests for L{PythonScript}.
"""
def test_notFoundRender(self):
"""
If the source file a L{PythonScript} is initialized with doesn't exist,
L{PythonScript.render} sets the HTTP response code to I{NOT FOUND}.
"""
resource = PythonScript(self.mktemp(), None)
request = DummyRequest([''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d

View file

@ -0,0 +1,114 @@
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""Test SOAP support."""
try:
import SOAPpy
except ImportError:
SOAPpy = None
class SOAPPublisher: pass
else:
from twisted.web import soap
SOAPPublisher = soap.SOAPPublisher
from twisted.trial import unittest
from twisted.web import server, error
from twisted.internet import reactor, defer
class Test(SOAPPublisher):
def soap_add(self, a, b):
return a + b
def soap_kwargs(self, a=1, b=2):
return a + b
soap_kwargs.useKeywords=True
def soap_triple(self, string, num):
return [string, num, None]
def soap_struct(self):
return SOAPpy.structType({"a": "c"})
def soap_defer(self, x):
return defer.succeed(x)
def soap_deferFail(self):
return defer.fail(ValueError())
def soap_fail(self):
raise RuntimeError
def soap_deferFault(self):
return defer.fail(ValueError())
def soap_complex(self):
return {"a": ["b", "c", 12, []], "D": "foo"}
def soap_dict(self, map, key):
return map[key]
class SOAPTestCase(unittest.TestCase):
def setUp(self):
self.publisher = Test()
self.p = reactor.listenTCP(0, server.Site(self.publisher),
interface="127.0.0.1")
self.port = self.p.getHost().port
def tearDown(self):
return self.p.stopListening()
def proxy(self):
return soap.Proxy("http://127.0.0.1:%d/" % self.port)
def testResults(self):
inputOutput = [
("add", (2, 3), 5),
("defer", ("a",), "a"),
("dict", ({"a": 1}, "a"), 1),
("triple", ("a", 1), ["a", 1, None])]
dl = []
for meth, args, outp in inputOutput:
d = self.proxy().callRemote(meth, *args)
d.addCallback(self.assertEqual, outp)
dl.append(d)
# SOAPpy kinda blows.
d = self.proxy().callRemote('complex')
d.addCallback(lambda result: result._asdict())
d.addCallback(self.assertEqual, {"a": ["b", "c", 12, []], "D": "foo"})
dl.append(d)
# We now return to our regularly scheduled program, already in progress.
return defer.DeferredList(dl, fireOnOneErrback=True)
def testMethodNotFound(self):
"""
Check that a non existing method return error 500.
"""
d = self.proxy().callRemote('doesntexist')
self.assertFailure(d, error.Error)
def cb(err):
self.assertEqual(int(err.status), 500)
d.addCallback(cb)
return d
def testLookupFunction(self):
"""
Test lookupFunction method on publisher, to see available remote
methods.
"""
self.assertTrue(self.publisher.lookupFunction("add"))
self.assertTrue(self.publisher.lookupFunction("fail"))
self.assertFalse(self.publisher.lookupFunction("foobar"))
if not SOAPpy:
SOAPTestCase.skip = "SOAPpy not installed"

View file

@ -0,0 +1,139 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._stan} portion of the L{twisted.web.template}
implementation.
"""
from twisted.web.template import Comment, CDATA, CharRef, Tag
from twisted.trial.unittest import TestCase
def proto(*a, **kw):
"""
Produce a new tag for testing.
"""
return Tag('hello')(*a, **kw)
class TestTag(TestCase):
"""
Tests for L{Tag}.
"""
def test_fillSlots(self):
"""
L{Tag.fillSlots} returns self.
"""
tag = proto()
self.assertIdentical(tag, tag.fillSlots(test='test'))
def test_cloneShallow(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. If the shallow flag is C{False}, that's where it
stops.
"""
innerList = ["inner list"]
tag = proto("How are you", innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone(deep=False)
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertEqual(clone.children, ["How are you", innerList])
self.assertNotIdentical(clone.children, tag.children)
self.assertIdentical(clone.children[1], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_cloneDeep(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. In its normal operating mode (where the deep flag is
C{True}, as is the default), it will clone all sub-lists and sub-tags.
"""
innerTag = proto("inner")
innerList = ["inner list"]
tag = proto("How are you", innerTag, innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone()
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertNotIdentical(clone.children, tag.children)
# sanity check
self.assertIdentical(tag.children[1], innerTag)
# clone should have sub-clone
self.assertNotIdentical(clone.children[1], innerTag)
# sanity check
self.assertIdentical(tag.children[2], innerList)
# clone should have sub-clone
self.assertNotIdentical(clone.children[2], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_clear(self):
"""
L{Tag.clear} removes all children from a tag, but leaves its attributes
in place.
"""
tag = proto("these are", "children", "cool", andSoIs='this-attribute')
tag.clear()
self.assertEqual(tag.children, [])
self.assertEqual(tag.attributes, {'andSoIs': 'this-attribute'})
def test_suffix(self):
"""
L{Tag.__call__} accepts Python keywords with a suffixed underscore as
the DOM attribute of that literal suffix.
"""
proto = Tag('div')
tag = proto()
tag(class_='a')
self.assertEqual(tag.attributes, {'class': 'a'})
def test_commentRepr(self):
"""
L{Comment.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(Comment(u"hello there")),
"Comment(u'hello there')")
def test_cdataRepr(self):
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA(u"test data")),
"CDATA(u'test data')")
def test_charrefRepr(self):
"""
L{CharRef.__repr__} returns a value which makes it easy to see what
character is referred to.
"""
snowman = ord(u"\N{SNOWMAN}")
self.assertEqual(repr(CharRef(snowman)), "CharRef(9731)")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,196 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.tap}.
"""
import os, stat
from twisted.python.usage import UsageError
from twisted.python.filepath import FilePath
from twisted.internet.interfaces import IReactorUNIX
from twisted.internet import reactor
from twisted.python.threadpool import ThreadPool
from twisted.trial.unittest import TestCase
from twisted.application import strports
from twisted.web.server import Site
from twisted.web.static import Data, File
from twisted.web.distrib import ResourcePublisher, UserDirectory
from twisted.web.wsgi import WSGIResource
from twisted.web.tap import Options, makePersonalServerFactory, makeService
from twisted.web.twcgi import CGIScript
from twisted.web.script import PythonScript
from twisted.spread.pb import PBServerFactory
application = object()
class ServiceTests(TestCase):
"""
Tests for the service creation APIs in L{twisted.web.tap}.
"""
def _pathOption(self):
"""
Helper for the I{--path} tests which creates a directory and creates
an L{Options} object which uses that directory as its static
filesystem root.
@return: A two-tuple of a L{FilePath} referring to the directory and
the value associated with the C{'root'} key in the L{Options}
instance after parsing a I{--path} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
options = Options()
options.parseOptions(['--path', path.path])
root = options['root']
return path, root
def test_path(self):
"""
The I{--path} option causes L{Options} to create a root resource
which serves responses from the specified path.
"""
path, root = self._pathOption()
self.assertIsInstance(root, File)
self.assertEqual(root.path, path.path)
def test_cgiProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{CGIScript} instance for any child with the C{".cgi"} extension.
"""
path, root = self._pathOption()
path.child("foo.cgi").setContent("")
self.assertIsInstance(root.getChild("foo.cgi", None), CGIScript)
def test_epyProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{PythonScript} instance for any child with the C{".epy"} extension.
"""
path, root = self._pathOption()
path.child("foo.epy").setContent("")
self.assertIsInstance(root.getChild("foo.epy", None), PythonScript)
def test_rpyProcessor(self):
"""
The I{--path} option creates a root resource which serves the
C{resource} global defined by the Python source in any child with
the C{".rpy"} extension.
"""
path, root = self._pathOption()
path.child("foo.rpy").setContent(
"from twisted.web.static import Data\n"
"resource = Data('content', 'major/minor')\n")
child = root.getChild("foo.rpy", None)
self.assertIsInstance(child, Data)
self.assertEqual(child.data, 'content')
self.assertEqual(child.type, 'major/minor')
def test_makePersonalServerFactory(self):
"""
L{makePersonalServerFactory} returns a PB server factory which has
as its root object a L{ResourcePublisher}.
"""
# The fact that this pile of objects can actually be used somehow is
# verified by twisted.web.test.test_distrib.
site = Site(Data("foo bar", "text/plain"))
serverFactory = makePersonalServerFactory(site)
self.assertIsInstance(serverFactory, PBServerFactory)
self.assertIsInstance(serverFactory.root, ResourcePublisher)
self.assertIdentical(serverFactory.root.site, site)
def test_personalServer(self):
"""
The I{--personal} option to L{makeService} causes it to return a
service which will listen on the server address given by the I{--port}
option.
"""
port = self.mktemp()
options = Options()
options.parseOptions(['--port', 'unix:' + port, '--personal'])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
if not IReactorUNIX.providedBy(reactor):
test_personalServer.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPersonalPath(self):
"""
If the I{--port} option not specified but the I{--personal} option is,
L{Options} defaults the port to C{UserDirectory.userSocketName} in the
user's home directory.
"""
options = Options()
options.parseOptions(['--personal'])
path = os.path.expanduser(
os.path.join('~', UserDirectory.userSocketName))
self.assertEqual(
strports.parse(options['port'], None)[:2],
('UNIX', (path, None)))
if not IReactorUNIX.providedBy(reactor):
test_defaultPersonalPath.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPort(self):
"""
If the I{--port} option is not specified, L{Options} defaults the port
to C{8080}.
"""
options = Options()
options.parseOptions([])
self.assertEqual(
strports.parse(options['port'], None)[:2],
('TCP', (8080, None)))
def test_wsgi(self):
"""
The I{--wsgi} option takes the fully-qualifed Python name of a WSGI
application object and creates a L{WSGIResource} at the root which
serves that application.
"""
options = Options()
options.parseOptions(['--wsgi', __name__ + '.application'])
root = options['root']
self.assertTrue(root, WSGIResource)
self.assertIdentical(root._reactor, reactor)
self.assertTrue(isinstance(root._threadpool, ThreadPool))
self.assertIdentical(root._application, application)
# The threadpool should start and stop with the reactor.
self.assertFalse(root._threadpool.started)
reactor.fireSystemEvent('startup')
self.assertTrue(root._threadpool.started)
self.assertFalse(root._threadpool.joined)
reactor.fireSystemEvent('shutdown')
self.assertTrue(root._threadpool.joined)
def test_invalidApplication(self):
"""
If I{--wsgi} is given an invalid name, L{Options.parseOptions}
raises L{UsageError}.
"""
options = Options()
for name in [__name__ + '.nosuchthing', 'foo.']:
exc = self.assertRaises(
UsageError, options.parseOptions, ['--wsgi', name])
self.assertEqual(str(exc), "No such WSGI application: %r" % (name,))

View file

@ -0,0 +1,820 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.template}
"""
from cStringIO import StringIO
from zope.interface.verify import verifyObject
from twisted.internet.defer import succeed, gatherResults
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.trial.util import suppress as SUPPRESS
from twisted.web.template import (
Element, TagLoader, renderer, tags, XMLFile, XMLString)
from twisted.web.iweb import ITemplateLoader
from twisted.web.error import (FlattenerError, MissingTemplateLoader,
MissingRenderMethod)
from twisted.web.template import renderElement
from twisted.web._element import UnexposedMethodError
from twisted.web.test._util import FlattenTestCase
from twisted.web.test.test_web import DummyRequest
from twisted.web.server import NOT_DONE_YET
_xmlFileSuppress = SUPPRESS(category=DeprecationWarning,
message="Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
class TagFactoryTests(TestCase):
"""
Tests for L{_TagFactory} through the publicly-exposed L{tags} object.
"""
def test_lookupTag(self):
"""
HTML tags can be retrieved through C{tags}.
"""
tag = tags.a
self.assertEqual(tag.tagName, "a")
def test_lookupHTML5Tag(self):
"""
Twisted supports the latest and greatest HTML tags from the HTML5
specification.
"""
tag = tags.video
self.assertEqual(tag.tagName, "video")
def test_lookupTransparentTag(self):
"""
To support transparent inclusion in templates, there is a special tag,
the transparent tag, which has no name of its own but is accessed
through the "transparent" attribute.
"""
tag = tags.transparent
self.assertEqual(tag.tagName, "")
def test_lookupInvalidTag(self):
"""
Invalid tags which are not part of HTML cause AttributeErrors when
accessed through C{tags}.
"""
self.assertRaises(AttributeError, getattr, tags, "invalid")
def test_lookupXMP(self):
"""
As a special case, the <xmp> tag is simply not available through
C{tags} or any other part of the templating machinery.
"""
self.assertRaises(AttributeError, getattr, tags, "xmp")
class ElementTests(TestCase):
"""
Tests for the awesome new L{Element} class.
"""
def test_missingTemplateLoader(self):
"""
L{Element.render} raises L{MissingTemplateLoader} if the C{loader}
attribute is C{None}.
"""
element = Element()
err = self.assertRaises(MissingTemplateLoader, element.render, None)
self.assertIdentical(err.element, element)
def test_missingTemplateLoaderRepr(self):
"""
A L{MissingTemplateLoader} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
self.assertIn('Pretty Repr Element',
repr(MissingTemplateLoader(PrettyReprElement())))
def test_missingRendererMethod(self):
"""
When called with the name which is not associated with a render method,
L{Element.lookupRenderMethod} raises L{MissingRenderMethod}.
"""
element = Element()
err = self.assertRaises(
MissingRenderMethod, element.lookupRenderMethod, "foo")
self.assertIdentical(err.element, element)
self.assertEqual(err.renderName, "foo")
def test_missingRenderMethodRepr(self):
"""
A L{MissingRenderMethod} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
s = repr(MissingRenderMethod(PrettyReprElement(),
'expectedMethod'))
self.assertIn('Pretty Repr Element', s)
self.assertIn('expectedMethod', s)
def test_definedRenderer(self):
"""
When called with the name of a defined render method,
L{Element.lookupRenderMethod} returns that render method.
"""
class ElementWithRenderMethod(Element):
@renderer
def foo(self, request, tag):
return "bar"
foo = ElementWithRenderMethod().lookupRenderMethod("foo")
self.assertEqual(foo(None, None), "bar")
def test_render(self):
"""
L{Element.render} loads a document from the C{loader} attribute and
returns it.
"""
class TemplateLoader(object):
def load(self):
return "result"
class StubElement(Element):
loader = TemplateLoader()
element = StubElement()
self.assertEqual(element.render(None), "result")
def test_misuseRenderer(self):
"""
If the L{renderer} decorator is called without any arguments, it will
raise a comprehensible exception.
"""
te = self.assertRaises(TypeError, renderer)
self.assertEqual(str(te),
"expose() takes at least 1 argument (0 given)")
def test_renderGetDirectlyError(self):
"""
Called directly, without a default, L{renderer.get} raises
L{UnexposedMethodError} when it cannot find a renderer.
"""
self.assertRaises(UnexposedMethodError, renderer.get, None,
"notARenderer")
class XMLFileReprTests(TestCase):
"""
Tests for L{twisted.web.template.XMLFile}'s C{__repr__}.
"""
def test_filePath(self):
"""
An L{XMLFile} with a L{FilePath} returns a useful repr().
"""
path = FilePath("/tmp/fake.xml")
self.assertEqual('<XMLFile of %r>' % (path,), repr(XMLFile(path)))
def test_filename(self):
"""
An L{XMLFile} with a filename returns a useful repr().
"""
fname = "/tmp/fake.xml"
self.assertEqual('<XMLFile of %r>' % (fname,), repr(XMLFile(fname)))
test_filename.suppress = [_xmlFileSuppress]
def test_file(self):
"""
An L{XMLFile} with a file object returns a useful repr().
"""
fobj = StringIO("not xml")
self.assertEqual('<XMLFile of %r>' % (fobj,), repr(XMLFile(fobj)))
test_file.suppress = [_xmlFileSuppress]
class XMLLoaderTestsMixin(object):
"""
@ivar templateString: Simple template to use to excercise the loaders.
@ivar deprecatedUse: C{True} if this use of L{XMLFile} is deprecated and
should emit a C{DeprecationWarning}.
"""
loaderFactory = None
templateString = '<p>Hello, world.</p>'
def test_load(self):
"""
Verify that the loader returns a tag with the correct children.
"""
loader = self.loaderFactory()
tag, = loader.load()
warnings = self.flushWarnings(offendingFunctions=[self.loaderFactory])
if self.deprecatedUse:
self.assertEqual(len(warnings), 1)
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
else:
self.assertEqual(len(warnings), 0)
self.assertEqual(tag.tagName, 'p')
self.assertEqual(tag.children, [u'Hello, world.'])
def test_loadTwice(self):
"""
If {load()} can be called on a loader twice the result should be the
same.
"""
loader = self.loaderFactory()
tags1 = loader.load()
tags2 = loader.load()
self.assertEqual(tags1, tags2)
test_loadTwice.suppress = [_xmlFileSuppress]
class XMLStringLoaderTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLString}
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with C{self.templateString}.
"""
return XMLString(self.templateString)
class XMLFileWithFilePathTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s L{FilePath} support.
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a L{FilePath} pointing to a
file that contains C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString)
return XMLFile(fp)
class XMLFileWithFileTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated file object support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a file object that contains
C{self.templateString}.
"""
return XMLFile(StringIO(self.templateString))
class XMLFileWithFilenameTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated filename support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a filename that points to a
file containing C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString)
return XMLFile(fp.path)
class FlattenIntegrationTests(FlattenTestCase):
"""
Tests for integration between L{Element} and
L{twisted.web._flatten.flatten}.
"""
def test_roundTrip(self):
"""
Given a series of parsable XML strings, verify that
L{twisted.web._flatten.flatten} will flatten the L{Element} back to the
input when sent on a round trip.
"""
fragments = [
"<p>Hello, world.</p>",
"<p><!-- hello, world --></p>",
"<p><![CDATA[Hello, world.]]></p>",
'<test1 xmlns:test2="urn:test2">'
'<test2:test3></test2:test3></test1>',
'<test1 xmlns="urn:test2"><test3></test3></test1>',
'<p>\xe2\x98\x83</p>',
]
deferreds = [
self.assertFlattensTo(Element(loader=XMLString(xml)), xml)
for xml in fragments]
return gatherResults(deferreds)
def test_entityConversion(self):
"""
When flattening an HTML entity, it should flatten out to the utf-8
representation if possible.
"""
element = Element(loader=XMLString('<p>&#9731;</p>'))
return self.assertFlattensTo(element, '<p>\xe2\x98\x83</p>')
def test_missingTemplateLoader(self):
"""
Rendering a Element without a loader attribute raises the appropriate
exception.
"""
return self.assertFlatteningRaises(Element(), MissingTemplateLoader)
def test_missingRenderMethod(self):
"""
Flattening an L{Element} with a C{loader} which has a tag with a render
directive fails with L{FlattenerError} if there is no available render
method to satisfy that directive.
"""
element = Element(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="unknownMethod" />
"""))
return self.assertFlatteningRaises(element, MissingRenderMethod)
def test_transparentRendering(self):
"""
A C{transparent} element should be eliminated from the DOM and rendered as
only its children.
"""
element = Element(loader=XMLString(
'<t:transparent '
'xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</t:transparent>'
))
return self.assertFlattensTo(element, "Hello, world.")
def test_attrRendering(self):
"""
An Element with an attr tag renders the vaule of its attr tag as an
attribute of its containing tag.
"""
element = Element(loader=XMLString(
'<a xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:attr name="href">http://example.com</t:attr>'
'Hello, world.'
'</a>'
))
return self.assertFlattensTo(element,
'<a href="http://example.com">Hello, world.</a>')
def test_errorToplevelAttr(self):
"""
A template with a toplevel C{attr} tag will not load; it will raise
L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
name='something'
>hello</t:attr>
""")
def test_errorUnnamedAttr(self):
"""
A template with an C{attr} tag with no C{name} attribute will not load;
it will raise L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<html><t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
>hello</t:attr></html>""")
def test_lenientPrefixBehavior(self):
"""
If the parser sees a prefix it doesn't recognize on an attribute, it
will pass it on through to serialization.
"""
theInput = (
'<hello:world hello:sample="testing" '
'xmlns:hello="http://made-up.example.com/ns/not-real">'
'This is a made-up tag.</hello:world>')
element = Element(loader=XMLString(theInput))
self.assertFlattensTo(element, theInput)
def test_deferredRendering(self):
"""
An Element with a render method which returns a Deferred will render
correctly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return succeed("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, "Hello, world.")
def test_loaderClassAttribute(self):
"""
If there is a non-None loader attribute on the class of an Element
instance but none on the instance itself, the class attribute is used.
"""
class SubElement(Element):
loader = XMLString("<p>Hello, world.</p>")
return self.assertFlattensTo(SubElement(), "<p>Hello, world.</p>")
def test_directiveRendering(self):
"""
An Element with a valid render directive has that directive invoked and
the result added to the output.
"""
renders = []
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
renders.append((self, request))
return tag("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, "<p>Hello, world.</p>")
def test_directiveRenderingOmittingTag(self):
"""
An Element with a render method which omits the containing tag
successfully removes that tag from the output.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return "Hello, world."
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, "Hello, world.")
def test_elementContainingStaticElement(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag(Element(
loader=XMLString("<em>Hello, world.</em>")))
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, "<p><em>Hello, world.</em></p>")
def test_elementUsingSlots(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag.fillSlots(test2='world.')
element = RenderfulElement(loader=XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"'
' t:render="renderMethod">'
'<t:slot name="test1" default="Hello, " />'
'<t:slot name="test2" />'
'</p>'
))
return self.assertFlattensTo(element, "<p>Hello, world.</p>")
def test_elementContainingDynamicElement(self):
"""
Directives in the document factory of a Element returned from a render
method of another Element are satisfied from the correct object: the
"inner" Element.
"""
class OuterElement(Element):
@renderer
def outerMethod(self, request, tag):
return tag(InnerElement(loader=XMLString("""
<t:ignored
xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="innerMethod" />
""")))
class InnerElement(Element):
@renderer
def innerMethod(self, request, tag):
return "Hello, world."
element = OuterElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="outerMethod" />
"""))
return self.assertFlattensTo(element, "<p>Hello, world.</p>")
def test_sameLoaderTwice(self):
"""
Rendering the output of a loader, or even the same element, should
return different output each time.
"""
sharedLoader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:transparent t:render="classCounter" /> '
'<t:transparent t:render="instanceCounter" />'
'</p>')
class DestructiveElement(Element):
count = 0
instanceCount = 0
loader = sharedLoader
@renderer
def classCounter(self, request, tag):
DestructiveElement.count += 1
return tag(str(DestructiveElement.count))
@renderer
def instanceCounter(self, request, tag):
self.instanceCount += 1
return tag(str(self.instanceCount))
e1 = DestructiveElement()
e2 = DestructiveElement()
self.assertFlattensImmediately(e1, "<p>1 1</p>")
self.assertFlattensImmediately(e1, "<p>2 2</p>")
self.assertFlattensImmediately(e2, "<p>3 1</p>")
class TagLoaderTests(FlattenTestCase):
"""
Tests for L{TagLoader}.
"""
def setUp(self):
self.loader = TagLoader(tags.i('test'))
def test_interface(self):
"""
An instance of L{TagLoader} provides L{ITemplateLoader}.
"""
self.assertTrue(verifyObject(ITemplateLoader, self.loader))
def test_loadsList(self):
"""
L{TagLoader.load} returns a list, per L{ITemplateLoader}.
"""
self.assertIsInstance(self.loader.load(), list)
def test_flatten(self):
"""
L{TagLoader} can be used in an L{Element}, and flattens as the tag used
to construct the L{TagLoader} would flatten.
"""
e = Element(self.loader)
self.assertFlattensImmediately(e, '<i>test</i>')
class TestElement(Element):
"""
An L{Element} that can be rendered successfully.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</p>')
class TestFailureElement(Element):
"""
An L{Element} that can be used in place of L{FailureElement} to verify
that L{renderElement} can render failures properly.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'I failed.'
'</p>')
def __init__(self, failure, loader=None):
self.failure = failure
class FailingElement(Element):
"""
An element that raises an exception when rendered.
"""
def render(self, request):
a = 42
b = 0
return a // b
class FakeSite(object):
"""
A minimal L{Site} object that we can use to test displayTracebacks
"""
displayTracebacks = False
class TestRenderElement(TestCase):
"""
Test L{renderElement}
"""
def setUp(self):
"""
Set up a common L{DummyRequest} and L{FakeSite}.
"""
self.request = DummyRequest([""])
self.request.site = FakeSite()
def test_simpleRender(self):
"""
L{renderElement} returns NOT_DONE_YET and eventually
writes the rendered L{Element} to the request before finishing the
request.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
"".join(self.request.written),
"<!DOCTYPE html>\n"
"<p>Hello, world.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailure(self):
"""
L{renderElement} handles failures by writing a minimal
error message to the request and finishing it.
"""
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
"".join(self.request.written),
('<!DOCTYPE html>\n'
'<div style="font-size:800%;'
'background-color:#FFF;'
'color:#F00'
'">An error occurred while rendering the response.</div>'))
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailureWithTraceback(self):
"""
L{renderElement} will render a traceback when rendering of
the element fails and our site is configured to display tracebacks.
"""
self.request.site.displayTracebacks = True
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
"".join(self.request.written),
"<!DOCTYPE html>\n<p>I failed.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
renderElement(self.request, element, _failElement=TestFailureElement)
return d
def test_nonDefaultDoctype(self):
"""
L{renderElement} will write the doctype string specified by the
doctype keyword argument.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
"".join(self.request.written),
('<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n'
'<p>Hello, world.</p>'))
d.addCallback(check)
renderElement(
self.request,
element,
doctype=(
'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">'))
return d
def test_noneDoctype(self):
"""
L{renderElement} will not write out a doctype if the doctype keyword
argument is C{None}.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
"".join(self.request.written),
'<p>Hello, world.</p>')
d.addCallback(check)
renderElement(self.request, element, doctype=None)
return d

View file

@ -0,0 +1,472 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.util}.
"""
from twisted.python.failure import Failure
from twisted.trial.unittest import TestCase
from twisted.internet import defer
from twisted.web import util
from twisted.web.error import FlattenerError
from twisted.web.util import (
redirectTo, _SourceLineElement,
_SourceFragmentElement, _FrameElement, _StackElement,
FailureElement, formatFailure, DeferredResource, htmlIndent)
from twisted.web.http import FOUND
from twisted.web.server import Request
from twisted.web.template import TagLoader, flattenString, tags
from twisted.web import resource
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
class RedirectToTestCase(TestCase):
"""
Tests for L{redirectTo}.
"""
def test_headersAndCode(self):
"""
L{redirectTo} will set the C{Location} and C{Content-Type} headers on
its request, and set the response code to C{FOUND}, so the browser will
be redirected.
"""
request = Request(DummyChannel(), True)
request.method = 'GET'
targetURL = "http://target.example.com/4321"
redirectTo(targetURL, request)
self.assertEqual(request.code, FOUND)
self.assertEqual(
request.responseHeaders.getRawHeaders('location'), [targetURL])
self.assertEqual(
request.responseHeaders.getRawHeaders('content-type'),
['text/html; charset=utf-8'])
def test_redirectToUnicodeURL(self) :
"""
L{redirectTo} will raise TypeError if unicode object is passed in URL
"""
request = Request(DummyChannel(), True)
request.method = 'GET'
targetURL = u'http://target.example.com/4321'
self.assertRaises(TypeError, redirectTo, targetURL, request)
class FailureElementTests(TestCase):
"""
Tests for L{FailureElement} and related helpers which can render a
L{Failure} as an HTML string.
"""
def setUp(self):
"""
Create a L{Failure} which can be used by the rendering tests.
"""
def lineNumberProbeAlsoBroken():
message = "This is a problem"
raise Exception(message)
# Figure out the line number from which the exception will be raised.
self.base = lineNumberProbeAlsoBroken.func_code.co_firstlineno + 1
try:
lineNumberProbeAlsoBroken()
except:
self.failure = Failure(captureVars=True)
self.frame = self.failure.frames[-1]
def test_sourceLineElement(self):
"""
L{_SourceLineElement} renders a source line and line number.
"""
element = _SourceLineElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"))),
50, " print 'hello'")
d = flattenString(None, element)
expected = (
u"<div><span>50</span><span>"
u" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}print 'hello'</span></div>")
d.addCallback(
self.assertEqual, expected.encode('utf-8'))
return d
def test_sourceFragmentElement(self):
"""
L{_SourceFragmentElement} renders source lines at and around the line
number indicated by a frame object.
"""
element = _SourceFragmentElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"),
render="sourceLines")),
self.frame)
source = [
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}message = '
u'"This is a problem"',
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}raise Exception(message)',
u'# Figure out the line number from which the exception will be '
u'raised.',
]
d = flattenString(None, element)
d.addCallback(
self.assertEqual,
''.join([
'<div class="snippet%sLine"><span>%d</span><span>%s</span>'
'</div>' % (
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(u" \N{NO-BREAK SPACE}" * 4 + sourceLine).encode(
'utf-8'))
for (lineNumber, sourceLine)
in enumerate(source)]))
return d
def test_frameElementFilename(self):
"""
The I{filename} renderer of L{_FrameElement} renders the filename
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="filename")),
self.frame)
d = flattenString(None, element)
d.addCallback(
# __file__ differs depending on whether an up-to-date .pyc file
# already existed.
self.assertEqual, "<span>" + __file__.rstrip('c') + "</span>")
return d
def test_frameElementLineNumber(self):
"""
The I{lineNumber} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="lineNumber")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, "<span>" + str(self.base + 1) + "</span>")
return d
def test_frameElementFunction(self):
"""
The I{function} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="function")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, "<span>lineNumberProbeAlsoBroken</span>")
return d
def test_frameElementSource(self):
"""
The I{source} renderer of L{_FrameElement} renders the source code near
the source filename/line number associated with the frame object used to
initialize the L{_FrameElement}.
"""
element = _FrameElement(None, self.frame)
renderer = element.lookupRenderMethod("source")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _SourceFragmentElement)
self.assertIdentical(result.frame, self.frame)
self.assertEqual([tag], result.loader.load())
def test_stackElement(self):
"""
The I{frames} renderer of L{_StackElement} renders each stack frame in
the list of frames used to initialize the L{_StackElement}.
"""
element = _StackElement(None, self.failure.frames[:2])
renderer = element.lookupRenderMethod("frames")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, list)
self.assertIsInstance(result[0], _FrameElement)
self.assertIdentical(result[0].frame, self.failure.frames[0])
self.assertIsInstance(result[1], _FrameElement)
self.assertIdentical(result[1].frame, self.failure.frames[1])
# They must not share the same tag object.
self.assertNotEqual(result[0].loader.load(), result[1].loader.load())
self.assertEqual(2, len(result))
def test_failureElementTraceback(self):
"""
The I{traceback} renderer of L{FailureElement} renders the failure's
stack frames using L{_StackElement}.
"""
element = FailureElement(self.failure)
renderer = element.lookupRenderMethod("traceback")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _StackElement)
self.assertIdentical(result.stackFrames, self.failure.frames)
self.assertEqual([tag], result.loader.load())
def test_failureElementType(self):
"""
The I{type} renderer of L{FailureElement} renders the failure's
exception type.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="type")))
d = flattenString(None, element)
d.addCallback(
self.assertEqual, "<span>exceptions.Exception</span>")
return d
def test_failureElementValue(self):
"""
The I{value} renderer of L{FailureElement} renders the value's exception
value.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="value")))
d = flattenString(None, element)
d.addCallback(
self.assertEqual, '<span>This is a problem</span>')
return d
class FormatFailureTests(TestCase):
"""
Tests for L{twisted.web.util.formatFailure} which returns an HTML string
representing the L{Failure} instance passed to it.
"""
def test_flattenerError(self):
"""
If there is an error flattening the L{Failure} instance,
L{formatFailure} raises L{FlattenerError}.
"""
self.assertRaises(FlattenerError, formatFailure, object())
def test_returnsBytes(self):
"""
The return value of L{formatFailure} is a C{str} instance (not a
C{unicode} instance) with numeric character references for any non-ASCII
characters meant to appear in the output.
"""
try:
raise Exception("Fake bug")
except:
result = formatFailure(Failure())
self.assertIsInstance(result, str)
self.assertTrue(all(ord(ch) < 128 for ch in result))
# Indentation happens to rely on NO-BREAK SPACE
self.assertIn("&#160;", result)
class DeprecatedHTMLHelpers(TestCase):
"""
The various HTML generation helper APIs in L{twisted.web.util} are
deprecated.
"""
def _htmlHelperDeprecationTest(self, functionName):
"""
Helper method which asserts that using the name indicated by
C{functionName} from the L{twisted.web.util} module emits a deprecation
warning.
"""
getattr(util, functionName)
warnings = self.flushWarnings([self._htmlHelperDeprecationTest])
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"twisted.web.util.%s was deprecated in Twisted 12.1.0: "
"See twisted.web.template." % (functionName,))
def test_htmlrepr(self):
"""
L{twisted.web.util.htmlrepr} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlrepr")
def test_saferepr(self):
"""
L{twisted.web.util.saferepr} is deprecated.
"""
self._htmlHelperDeprecationTest("saferepr")
def test_htmlUnknown(self):
"""
L{twisted.web.util.htmlUnknown} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlUnknown")
def test_htmlDict(self):
"""
L{twisted.web.util.htmlDict} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlDict")
def test_htmlList(self):
"""
L{twisted.web.util.htmlList} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlList")
def test_htmlInst(self):
"""
L{twisted.web.util.htmlInst} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlInst")
def test_htmlString(self):
"""
L{twisted.web.util.htmlString} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlString")
def test_htmlIndent(self):
"""
L{twisted.web.util.htmlIndent} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlIndent")
def test_htmlFunc(self):
"""
L{twisted.web.util.htmlFunc} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlFunc")
def test_htmlReprTypes(self):
"""
L{twisted.web.util.htmlReprTypes} is deprecated.
"""
self._htmlHelperDeprecationTest("htmlReprTypes")
def test_stylesheet(self):
"""
L{twisted.web.util.stylesheet} is deprecated.
"""
self._htmlHelperDeprecationTest("stylesheet")
class SDResource(resource.Resource):
def __init__(self,default):
self.default = default
def getChildWithDefault(self, name, request):
d = defer.succeed(self.default)
resource = util.DeferredResource(d)
return resource.getChildWithDefault(name, request)
class DeferredResourceTests(TestCase):
"""
Tests for L{DeferredResource}.
"""
def testDeferredResource(self):
r = resource.Resource()
r.isLeaf = 1
s = SDResource(r)
d = DummyRequest(['foo', 'bar', 'baz'])
resource.getChildForRequest(s, d)
self.assertEqual(d.postpath, ['bar', 'baz'])
def test_render(self):
"""
L{DeferredResource} uses the request object's C{render} method to
render the resource which is the result of the L{Deferred} being
handled.
"""
rendered = []
request = DummyRequest([])
request.render = rendered.append
result = resource.Resource()
deferredResource = DeferredResource(defer.succeed(result))
deferredResource.render(request)
self.assertEqual(rendered, [result])
class HtmlIndentTests(TestCase):
"""
Tests for L{htmlIndent}
"""
def test_simpleInput(self):
"""
L{htmlIndent} transparently processes input with no special cases
inside.
"""
line = "foo bar"
self.assertEqual(line, htmlIndent(line))
def test_escapeHtml(self):
"""
L{htmlIndent} escapes HTML from its input.
"""
line = "<br />"
self.assertEqual("&lt;br /&gt;", htmlIndent(line))
def test_stripTrailingWhitespace(self):
"""
L{htmlIndent} removes trailing whitespaces from its input.
"""
line = " foo bar "
self.assertEqual(" foo bar", htmlIndent(line))
def test_forceSpacingFromSpaceCharacters(self):
"""
If L{htmlIndent} detects consecutive space characters, it forces the
rendering by substituting unbreakable space.
"""
line = " foo bar"
self.assertEqual("&nbsp;foo&nbsp;bar", htmlIndent(line))
def test_indentFromTabCharacters(self):
"""
L{htmlIndent} replaces tab characters with unbreakable spaces.
"""
line = "\tfoo"
self.assertEqual("&nbsp; &nbsp; &nbsp; &nbsp; foo", htmlIndent(line))

View file

@ -0,0 +1,105 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.vhost}.
"""
from twisted.internet.defer import gatherResults
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.static import Data
from twisted.web.vhost import NameVirtualHost
from twisted.web.test.test_web import DummyRequest
from twisted.web.test._util import _render
class NameVirtualHostTests(TestCase):
"""
Tests for L{NameVirtualHost}.
"""
def test_renderWithoutHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not C{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data("correct result", "")
request = DummyRequest([''])
self.assertEqual(
virtualHostResource.render(request), "correct result")
def test_renderWithoutHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is C{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderWithHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the resource
which is the value in the instance's C{host} dictionary corresponding
to the key indicated by the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.addHost('example.org', Data("winner", ""))
request = DummyRequest([''])
request.headers['host'] = 'example.org'
d = _render(virtualHostResource, request)
def cbRendered(ignored, request):
self.assertEqual(''.join(request.written), "winner")
d.addCallback(cbRendered, request)
# The port portion of the Host header should not be considered.
requestWithPort = DummyRequest([''])
requestWithPort.headers['host'] = 'example.org:8000'
dWithPort = _render(virtualHostResource, requestWithPort)
def cbRendered(ignored, requestWithPort):
self.assertEqual(''.join(requestWithPort.written), "winner")
dWithPort.addCallback(cbRendered, requestWithPort)
return gatherResults([d, dWithPort])
def test_renderWithUnknownHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not C{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data("correct data", "")
request = DummyRequest([''])
request.headers['host'] = 'example.com'
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(''.join(request.written), "correct data")
d.addCallback(cbRendered)
return d
def test_renderWithUnknownHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is C{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
request.headers['host'] = 'example.com'
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,818 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for XML-RPC support in L{twisted.web.xmlrpc}.
"""
import datetime
import xmlrpclib
from StringIO import StringIO
from twisted.trial import unittest
from twisted.web import xmlrpc
from twisted.web.xmlrpc import (
XMLRPC, payloadTemplate, addIntrospection, _QueryFactory, withRequest)
from twisted.web import server, static, client, error, http
from twisted.internet import reactor, defer
from twisted.internet.error import ConnectionDone
from twisted.python import failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.test.test_web import DummyRequest
try:
import twisted.internet.ssl
except ImportError:
sslSkip = "OpenSSL not present"
else:
sslSkip = None
class AsyncXMLRPCTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of Deferreds.
"""
def setUp(self):
self.request = DummyRequest([''])
self.request.method = 'POST'
self.request.content = StringIO(
payloadTemplate % ('async', xmlrpclib.dumps(())))
result = self.result = defer.Deferred()
class AsyncResource(XMLRPC):
def xmlrpc_async(self):
return result
self.resource = AsyncResource()
def test_deferredResponse(self):
"""
If an L{XMLRPC} C{xmlrpc_*} method returns a L{defer.Deferred}, the
response to the request is the result of that L{defer.Deferred}.
"""
self.resource.render(self.request)
self.assertEqual(self.request.written, [])
self.result.callback("result")
resp = xmlrpclib.loads("".join(self.request.written))
self.assertEqual(resp, (('result',), None))
self.assertEqual(self.request.finished, 1)
def test_interruptedDeferredResponse(self):
"""
While waiting for the L{Deferred} returned by an L{XMLRPC} C{xmlrpc_*}
method to fire, the connection the request was issued over may close.
If this happens, neither C{write} nor C{finish} is called on the
request.
"""
self.resource.render(self.request)
self.request.processingFailed(
failure.Failure(ConnectionDone("Simulated")))
self.result.callback("result")
self.assertEqual(self.request.written, [])
self.assertEqual(self.request.finished, 0)
class TestRuntimeError(RuntimeError):
pass
class TestValueError(ValueError):
pass
class Test(XMLRPC):
# If you add xmlrpc_ methods to this class, go change test_listMethods
# below.
FAILURE = 666
NOT_FOUND = 23
SESSION_EXPIRED = 42
def xmlrpc_echo(self, arg):
return arg
# the doc string is part of the test
def xmlrpc_add(self, a, b):
"""
This function add two numbers.
"""
return a + b
xmlrpc_add.signature = [['int', 'int', 'int'],
['double', 'double', 'double']]
# the doc string is part of the test
def xmlrpc_pair(self, string, num):
"""
This function puts the two arguments in an array.
"""
return [string, num]
xmlrpc_pair.signature = [['array', 'string', 'int']]
# the doc string is part of the test
def xmlrpc_defer(self, x):
"""Help for defer."""
return defer.succeed(x)
def xmlrpc_deferFail(self):
return defer.fail(TestValueError())
# don't add a doc string, it's part of the test
def xmlrpc_fail(self):
raise TestRuntimeError
def xmlrpc_fault(self):
return xmlrpc.Fault(12, "hello")
def xmlrpc_deferFault(self):
return defer.fail(xmlrpc.Fault(17, "hi"))
def xmlrpc_complex(self):
return {"a": ["b", "c", 12, []], "D": "foo"}
def xmlrpc_dict(self, map, key):
return map[key]
xmlrpc_dict.help = 'Help for dict.'
@withRequest
def xmlrpc_withRequest(self, request, other):
"""
A method decorated with L{withRequest} which can be called by
a test to verify that the request object really is passed as
an argument.
"""
return (
# as a proof that request is a request
request.method +
# plus proof other arguments are still passed along
' ' + other)
def lookupProcedure(self, procedurePath):
try:
return XMLRPC.lookupProcedure(self, procedurePath)
except xmlrpc.NoSuchFunction:
if procedurePath.startswith("SESSION"):
raise xmlrpc.Fault(self.SESSION_EXPIRED,
"Session non-existant/expired.")
else:
raise
class TestLookupProcedure(XMLRPC):
"""
This is a resource which customizes procedure lookup to be used by the tests
of support for this customization.
"""
def echo(self, x):
return x
def lookupProcedure(self, procedureName):
"""
Lookup a procedure from a fixed set of choices, either I{echo} or
I{system.listeMethods}.
"""
if procedureName == 'echo':
return self.echo
raise xmlrpc.NoSuchFunction(
self.NOT_FOUND, 'procedure %s not found' % (procedureName,))
class TestListProcedures(XMLRPC):
"""
This is a resource which customizes procedure enumeration to be used by the
tests of support for this customization.
"""
def listProcedures(self):
"""
Return a list of a single method this resource will claim to support.
"""
return ['foo']
class TestAuthHeader(Test):
"""
This is used to get the header info so that we can test
authentication.
"""
def __init__(self):
Test.__init__(self)
self.request = None
def render(self, request):
self.request = request
return Test.render(self, request)
def xmlrpc_authinfo(self):
return self.request.getUser(), self.request.getPassword()
class TestQueryProtocol(xmlrpc.QueryProtocol):
"""
QueryProtocol for tests that saves headers received inside the factory.
"""
def connectionMade(self):
self.factory.transport = self.transport
xmlrpc.QueryProtocol.connectionMade(self)
def handleHeader(self, key, val):
self.factory.headers[key.lower()] = val
class TestQueryFactory(xmlrpc._QueryFactory):
"""
QueryFactory using L{TestQueryProtocol} for saving headers.
"""
protocol = TestQueryProtocol
def __init__(self, *args, **kwargs):
self.headers = {}
xmlrpc._QueryFactory.__init__(self, *args, **kwargs)
class TestQueryFactoryCancel(xmlrpc._QueryFactory):
"""
QueryFactory that saves a reference to the
L{twisted.internet.interfaces.IConnector} to test connection lost.
"""
def startedConnecting(self, connector):
self.connector = connector
class XMLRPCTestCase(unittest.TestCase):
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(Test()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def tearDown(self):
self.factories = []
return self.p.stopListening()
def queryFactory(self, *args, **kwargs):
"""
Specific queryFactory for proxy that uses our custom
L{TestQueryFactory}, and save factories.
"""
factory = TestQueryFactory(*args, **kwargs)
self.factories.append(factory)
return factory
def proxy(self, factory=None):
"""
Return a new xmlrpc.Proxy for the test site created in
setUp(), using the given factory as the queryFactory, or
self.queryFactory if no factory is provided.
"""
p = xmlrpc.Proxy("http://127.0.0.1:%d/" % self.port)
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
def test_results(self):
inputOutput = [
("add", (2, 3), 5),
("defer", ("a",), "a"),
("dict", ({"a": 1}, "a"), 1),
("pair", ("a", 1), ["a", 1]),
("complex", (), {"a": ["b", "c", 12, []], "D": "foo"})]
dl = []
for meth, args, outp in inputOutput:
d = self.proxy().callRemote(meth, *args)
d.addCallback(self.assertEqual, outp)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_errors(self):
"""
Verify that for each way a method exposed via XML-RPC can fail, the
correct 'Content-type' header is set in the response and that the
client-side Deferred is errbacked with an appropriate C{Fault}
instance.
"""
dl = []
for code, methodName in [(666, "fail"), (666, "deferFail"),
(12, "fault"), (23, "noSuchMethod"),
(17, "deferFault"), (42, "SESSION_TEST")]:
d = self.proxy().callRemote(methodName)
d = self.assertFailure(d, xmlrpc.Fault)
d.addCallback(lambda exc, code=code:
self.assertEqual(exc.faultCode, code))
dl.append(d)
d = defer.DeferredList(dl, fireOnOneErrback=True)
def cb(ign):
for factory in self.factories:
self.assertEqual(factory.headers['content-type'],
'text/xml')
self.flushLoggedErrors(TestRuntimeError, TestValueError)
d.addCallback(cb)
return d
def test_cancel(self):
"""
A deferred from the Proxy can be cancelled, disconnecting
the L{twisted.internet.interfaces.IConnector}.
"""
def factory(*args, **kw):
factory.f = TestQueryFactoryCancel(*args, **kw)
return factory.f
d = self.proxy(factory).callRemote('add', 2, 3)
self.assertNotEquals(factory.f.connector.state, "disconnected")
d.cancel()
self.assertEqual(factory.f.connector.state, "disconnected")
d = self.assertFailure(d, defer.CancelledError)
return d
def test_errorGet(self):
"""
A classic GET on the xml server should return a NOT_ALLOWED.
"""
d = client.getPage("http://127.0.0.1:%d/" % (self.port,))
d = self.assertFailure(d, error.Error)
d.addCallback(
lambda exc: self.assertEqual(int(exc.args[0]), http.NOT_ALLOWED))
return d
def test_errorXMLContent(self):
"""
Test that an invalid XML input returns an L{xmlrpc.Fault}.
"""
d = client.getPage("http://127.0.0.1:%d/" % (self.port,),
method="POST", postdata="foo")
def cb(result):
self.assertRaises(xmlrpc.Fault, xmlrpclib.loads, result)
d.addCallback(cb)
return d
def test_datetimeRoundtrip(self):
"""
If an L{xmlrpclib.DateTime} is passed as an argument to an XML-RPC
call and then returned by the server unmodified, the result should
be equal to the original object.
"""
when = xmlrpclib.DateTime()
d = self.proxy().callRemote("echo", when)
d.addCallback(self.assertEqual, when)
return d
def test_doubleEncodingError(self):
"""
If it is not possible to encode a response to the request (for example,
because L{xmlrpclib.dumps} raises an exception when encoding a
L{Fault}) the exception which prevents the response from being
generated is logged and the request object is finished anyway.
"""
d = self.proxy().callRemote("echo", "")
# *Now* break xmlrpclib.dumps. Hopefully the client already used it.
def fakeDumps(*args, **kwargs):
raise RuntimeError("Cannot encode anything at all!")
self.patch(xmlrpclib, 'dumps', fakeDumps)
# It doesn't matter how it fails, so long as it does. Also, it happens
# to fail with an implementation detail exception right now, not
# something suitable as part of a public interface.
d = self.assertFailure(d, Exception)
def cbFailed(ignored):
# The fakeDumps exception should have been logged.
self.assertEqual(len(self.flushLoggedErrors(RuntimeError)), 1)
d.addCallback(cbFailed)
return d
def test_closeConnectionAfterRequest(self):
"""
The connection to the web server is closed when the request is done.
"""
d = self.proxy().callRemote('echo', '')
def responseDone(ignored):
[factory] = self.factories
self.assertFalse(factory.transport.connected)
self.assertTrue(factory.transport.disconnected)
return d.addCallback(responseDone)
def test_tcpTimeout(self):
"""
For I{HTTP} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectTCP call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy("http://127.0.0.1:69", connectTimeout=2.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.tcpClients[0][3], 2.0)
def test_sslTimeout(self):
"""
For I{HTTPS} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectSSL call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy("https://127.0.0.1:69", connectTimeout=3.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.sslClients[0][4], 3.0)
test_sslTimeout.skip = sslSkip
class XMLRPCTestCase2(XMLRPCTestCase):
"""
Test with proxy that doesn't add a slash.
"""
def proxy(self, factory=None):
p = xmlrpc.Proxy("http://127.0.0.1:%d" % self.port)
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
class XMLRPCTestPublicLookupProcedure(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of subclasses which override
C{lookupProcedure} and C{listProcedures}.
"""
def createServer(self, resource):
self.p = reactor.listenTCP(
0, server.Site(resource), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy('http://127.0.0.1:%d' % self.port)
def test_lookupProcedure(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to find
procedures that are not defined using a C{xmlrpc_}-prefixed method name.
"""
self.createServer(TestLookupProcedure())
what = "hello"
d = self.proxy.callRemote("echo", what)
d.addCallback(self.assertEqual, what)
return d
def test_errors(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to raise
L{NoSuchFunction} to indicate that a requested method is not available
to be called, signalling a fault to the XML-RPC client.
"""
self.createServer(TestLookupProcedure())
d = self.proxy.callRemote("xxxx", "hello")
d = self.assertFailure(d, xmlrpc.Fault)
return d
def test_listMethods(self):
"""
A subclass of L{XMLRPC} can override C{listProcedures} to define
Overriding listProcedures should prevent introspection from being
broken.
"""
resource = TestListProcedures()
addIntrospection(resource)
self.createServer(resource)
d = self.proxy.callRemote("system.listMethods")
def listed(procedures):
# The list will also include other introspection procedures added by
# addIntrospection. We just want to see "foo" from our customized
# listProcedures.
self.assertIn('foo', procedures)
d.addCallback(listed)
return d
class SerializationConfigMixin:
"""
Mixin which defines a couple tests which should pass when a particular flag
is passed to L{XMLRPC}.
These are not meant to be exhaustive serialization tests, since L{xmlrpclib}
does all of the actual serialization work. They are just meant to exercise
a few codepaths to make sure we are calling into xmlrpclib correctly.
@ivar flagName: A C{str} giving the name of the flag which must be passed to
L{XMLRPC} to allow the tests to pass. Subclasses should set this.
@ivar value: A value which the specified flag will allow the serialization
of. Subclasses should set this.
"""
def setUp(self):
"""
Create a new XML-RPC server with C{allowNone} set to C{True}.
"""
kwargs = {self.flagName: True}
self.p = reactor.listenTCP(
0, server.Site(Test(**kwargs)), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
"http://127.0.0.1:%d/" % (self.port,), **kwargs)
def test_roundtripValue(self):
"""
C{self.value} can be round-tripped over an XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', self.value)
d.addCallback(self.assertEqual, self.value)
return d
def test_roundtripNestedValue(self):
"""
A C{dict} which contains C{self.value} can be round-tripped over an
XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', {'a': self.value})
d.addCallback(self.assertEqual, {'a': self.value})
return d
class XMLRPCAllowNoneTestCase(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing C{None} when the C{allowNone} flag is set.
"""
flagName = "allowNone"
value = None
class XMLRPCUseDateTimeTestCase(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing a C{datetime.datetime} instance when the C{useDateTime}
flag is set.
"""
flagName = "useDateTime"
value = datetime.datetime(2000, 12, 28, 3, 45, 59)
class XMLRPCTestAuthenticated(XMLRPCTestCase):
"""
Test with authenticated proxy. We run this with the same inout/ouput as
above.
"""
user = "username"
password = "asecret"
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(TestAuthHeader()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_authInfoInURL(self):
p = xmlrpc.Proxy("http://%s:%s@127.0.0.1:%d/" % (
self.user, self.password, self.port))
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_explicitAuthInfo(self):
p = xmlrpc.Proxy("http://127.0.0.1:%d/" % (
self.port,), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_longPassword(self):
"""
C{QueryProtocol} uses the C{base64.b64encode} function to encode user
name and password in the I{Authorization} header, so that it doesn't
embed new lines when using long inputs.
"""
longPassword = self.password * 40
p = xmlrpc.Proxy("http://127.0.0.1:%d/" % (
self.port,), self.user, longPassword)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, longPassword])
return d
def test_explicitAuthInfoOverride(self):
p = xmlrpc.Proxy("http://wrong:info@127.0.0.1:%d/" % (
self.port,), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
class XMLRPCTestIntrospection(XMLRPCTestCase):
def setUp(self):
xmlrpc = Test()
addIntrospection(xmlrpc)
self.p = reactor.listenTCP(0, server.Site(xmlrpc),interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_listMethods(self):
def cbMethods(meths):
meths.sort()
self.assertEqual(
meths,
['add', 'complex', 'defer', 'deferFail',
'deferFault', 'dict', 'echo', 'fail', 'fault',
'pair', 'system.listMethods',
'system.methodHelp',
'system.methodSignature', 'withRequest'])
d = self.proxy().callRemote("system.listMethods")
d.addCallback(cbMethods)
return d
def test_methodHelp(self):
inputOutputs = [
("defer", "Help for defer."),
("fail", ""),
("dict", "Help for dict.")]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodHelp", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_methodSignature(self):
inputOutputs = [
("defer", ""),
("add", [['int', 'int', 'int'],
['double', 'double', 'double']]),
("pair", [['array', 'string', 'int']])]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodSignature", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
class XMLRPCClientErrorHandling(unittest.TestCase):
"""
Test error handling on the xmlrpc client.
"""
def setUp(self):
self.resource = static.Data(
"This text is not a valid XML-RPC response.",
"text/plain")
self.resource.isLeaf = True
self.port = reactor.listenTCP(0, server.Site(self.resource),
interface='127.0.0.1')
def tearDown(self):
return self.port.stopListening()
def test_erroneousResponse(self):
"""
Test that calling the xmlrpc client on a static http server raises
an exception.
"""
proxy = xmlrpc.Proxy("http://127.0.0.1:%d/" %
(self.port.getHost().port,))
return self.assertFailure(proxy.callRemote("someMethod"), Exception)
class TestQueryFactoryParseResponse(unittest.TestCase):
"""
Test the behaviour of L{_QueryFactory.parseResponse}.
"""
def setUp(self):
# The _QueryFactory that we are testing. We don't care about any
# of the constructor parameters.
self.queryFactory = _QueryFactory(
path=None, host=None, method='POST', user=None, password=None,
allowNone=False, args=())
# An XML-RPC response that will parse without raising an error.
self.goodContents = xmlrpclib.dumps(('',))
# An 'XML-RPC response' that will raise a parsing error.
self.badContents = 'invalid xml'
# A dummy 'reason' to pass to clientConnectionLost. We don't care
# what it is.
self.reason = failure.Failure(ConnectionDone())
def test_parseResponseCallbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as a callback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addCallback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.goodContents)
return d
def test_parseResponseErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.badContents)
return d
def test_badStatusErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.badStatus}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.badStatus('status', 'message')
return d
def test_parseResponseWithoutData(self):
"""
Some server can send a response without any data:
L{_QueryFactory.parseResponse} should catch the error and call the
result errback.
"""
content = """
<methodResponse>
<params>
<param>
</param>
</params>
</methodResponse>"""
d = self.queryFactory.deferred
self.queryFactory.parseResponse(content)
return self.assertFailure(d, IndexError)
class XMLRPCTestWithRequest(unittest.TestCase):
def setUp(self):
self.resource = Test()
def test_withRequest(self):
"""
When an XML-RPC method is called and the implementation is
decorated with L{withRequest}, the request object is passed as
the first argument.
"""
request = DummyRequest('/RPC2')
request.method = "POST"
request.content = StringIO(xmlrpclib.dumps(("foo",), 'withRequest'))
def valid(n, request):
data = xmlrpclib.loads(request.written[0])
self.assertEqual(data, (('POST foo',), None))
d = request.notifyFinish().addCallback(valid, request)
self.resource.render_POST(request)
return d