add Linux_i686

This commit is contained in:
j 2014-05-17 18:11:40 +00:00 committed by Ubuntu
commit 95cd9b11f2
1644 changed files with 564260 additions and 0 deletions

View file

@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Test: Unit Tests for Twisted.
"""

View file

@ -0,0 +1,17 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# This makes sure Twisted-using child processes used in the test suite import
# the correct version of Twisted (ie, the version of Twisted under test).
# This is a copy of the bin/_preamble.py script because it's not clear how to
# use the functionality for both things without having a copy.
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.exists(os.path.join(path, 'twisted', '__init__.py')):
sys.path.insert(0, path)
break
path = os.path.dirname(path)

View file

@ -0,0 +1,34 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.python import components
from zope.interface import implements, Interface
def foo():
return 2
class X:
def __init__(self, x):
self.x = x
def do(self):
#print 'X',self.x,'doing!'
pass
class XComponent(components.Componentized):
pass
class IX(Interface):
pass
class XA(components.Adapter):
implements(IX)
def method(self):
# Kick start :(
pass
components.registerAdapter(XA, X, IX)

View file

@ -0,0 +1,407 @@
# -*- test-case-name: twisted.test.test_amp.TLSTest,twisted.test.test_iosim -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Utilities and helpers for simulating a network
"""
from __future__ import print_function
import itertools
try:
from OpenSSL.SSL import Error as NativeOpenSSLError
except ImportError:
pass
from zope.interface import implementer, directlyProvides
from twisted.python.failure import Failure
from twisted.internet import error
from twisted.internet import interfaces
class TLSNegotiation:
def __init__(self, obj, connectState):
self.obj = obj
self.connectState = connectState
self.sent = False
self.readyToSend = connectState
def __repr__(self):
return 'TLSNegotiation(%r)' % (self.obj,)
def pretendToVerify(self, other, tpt):
# Set the transport problems list here? disconnections?
# hmmmmm... need some negative path tests.
if not self.obj.iosimVerify(other.obj):
tpt.disconnectReason = NativeOpenSSLError()
tpt.loseConnection()
implementer(interfaces.IAddress)
class FakeAddress(object):
"""
The default address type for the host and peer of L{FakeTransport}
connections.
"""
@implementer(interfaces.ITransport,
interfaces.ITLSTransport)
class FakeTransport:
"""
A wrapper around a file-like object to make it behave as a Transport.
This doesn't actually stream the file to the attached protocol,
and is thus useful mainly as a utility for debugging protocols.
"""
_nextserial = staticmethod(lambda counter=itertools.count(): next(counter))
closed = 0
disconnecting = 0
disconnected = 0
disconnectReason = error.ConnectionDone("Connection done")
producer = None
streamingProducer = 0
tls = None
def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
"""
@param protocol: This transport will deliver bytes to this protocol.
@type protocol: L{IProtocol} provider
@param isServer: C{True} if this is the accepting side of the
connection, C{False} if it is the connecting side.
@type isServer: L{bool}
@param hostAddress: The value to return from C{getHost}. C{None}
results in a new L{FakeAddress} being created to use as the value.
@type hostAddress: L{IAddress} provider or L{NoneType}
@param peerAddress: The value to return from C{getPeer}. C{None}
results in a new L{FakeAddress} being created to use as the value.
@type peerAddress: L{IAddress} provider or L{NoneType}
"""
self.protocol = protocol
self.isServer = isServer
self.stream = []
self.serial = self._nextserial()
if hostAddress is None:
hostAddress = FakeAddress()
self.hostAddress = hostAddress
if peerAddress is None:
peerAddress = FakeAddress()
self.peerAddress = peerAddress
def __repr__(self):
return 'FakeTransport<%s,%s,%s>' % (
self.isServer and 'S' or 'C', self.serial,
self.protocol.__class__.__name__)
def write(self, data):
if self.tls is not None:
self.tlsbuf.append(data)
else:
self.stream.append(data)
def _checkProducer(self):
# Cheating; this is called at "idle" times to allow producers to be
# found and dealt with
if self.producer:
self.producer.resumeProducing()
def registerProducer(self, producer, streaming):
"""From abstract.FileDescriptor
"""
self.producer = producer
self.streamingProducer = streaming
if not streaming:
producer.resumeProducing()
def unregisterProducer(self):
self.producer = None
def stopConsuming(self):
self.unregisterProducer()
self.loseConnection()
def writeSequence(self, iovec):
self.write("".join(iovec))
def loseConnection(self):
self.disconnecting = True
def abortConnection(self):
"""
For the time being, this is the same as loseConnection; no buffered
data will be lost.
"""
self.disconnecting = True
def reportDisconnect(self):
if self.tls is not None:
# We were in the middle of negotiating! Must have been a TLS problem.
err = NativeOpenSSLError()
else:
err = self.disconnectReason
self.protocol.connectionLost(Failure(err))
def logPrefix(self):
"""
Identify this transport/event source to the logging system.
"""
return "iosim"
def getPeer(self):
return self.peerAddress
def getHost(self):
return self.hostAddress
def resumeProducing(self):
# Never sends data anyways
pass
def pauseProducing(self):
# Never sends data anyways
pass
def stopProducing(self):
self.loseConnection()
def startTLS(self, contextFactory, beNormal=True):
# Nothing's using this feature yet, but startTLS has an undocumented
# second argument which defaults to true; if set to False, servers will
# behave like clients and clients will behave like servers.
connectState = self.isServer ^ beNormal
self.tls = TLSNegotiation(contextFactory, connectState)
self.tlsbuf = []
def getOutBuffer(self):
"""
Get the pending writes from this transport, clearing them from the
pending buffer.
@return: the bytes written with C{transport.write}
@rtype: L{bytes}
"""
S = self.stream
if S:
self.stream = []
return b''.join(S)
elif self.tls is not None:
if self.tls.readyToSend:
# Only _send_ the TLS negotiation "packet" if I'm ready to.
self.tls.sent = True
return self.tls
else:
return None
else:
return None
def bufferReceived(self, buf):
if isinstance(buf, TLSNegotiation):
assert self.tls is not None # By the time you're receiving a
# negotiation, you have to have called
# startTLS already.
if self.tls.sent:
self.tls.pretendToVerify(buf, self)
self.tls = None # we're done with the handshake if we've gotten
# this far... although maybe it failed...?
# TLS started! Unbuffer...
b, self.tlsbuf = self.tlsbuf, None
self.writeSequence(b)
directlyProvides(self, interfaces.ISSLTransport)
else:
# We haven't sent our own TLS negotiation: time to do that!
self.tls.readyToSend = True
else:
self.protocol.dataReceived(buf)
def makeFakeClient(clientProtocol):
"""
Create and return a new in-memory transport hooked up to the given protocol.
@param clientProtocol: The client protocol to use.
@type clientProtocol: L{IProtocol} provider
@return: The transport.
@rtype: L{FakeTransport}
"""
return FakeTransport(clientProtocol, isServer=False)
def makeFakeServer(serverProtocol):
"""
Create and return a new in-memory transport hooked up to the given protocol.
@param serverProtocol: The server protocol to use.
@type serverProtocol: L{IProtocol} provider
@return: The transport.
@rtype: L{FakeTransport}
"""
return FakeTransport(serverProtocol, isServer=True)
class IOPump:
"""Utility to pump data between clients and servers for protocol testing.
Perhaps this is a utility worthy of being in protocol.py?
"""
def __init__(self, client, server, clientIO, serverIO, debug):
self.client = client
self.server = server
self.clientIO = clientIO
self.serverIO = serverIO
self.debug = debug
def flush(self, debug=False):
"""Pump until there is no more input or output.
Returns whether any data was moved.
"""
result = False
for x in range(1000):
if self.pump(debug):
result = True
else:
break
else:
assert 0, "Too long"
return result
def pump(self, debug=False):
"""Move data back and forth.
Returns whether any data was moved.
"""
if self.debug or debug:
print('-- GLUG --')
sData = self.serverIO.getOutBuffer()
cData = self.clientIO.getOutBuffer()
self.clientIO._checkProducer()
self.serverIO._checkProducer()
if self.debug or debug:
print('.')
# XXX slightly buggy in the face of incremental output
if cData:
print('C: ' + repr(cData))
if sData:
print('S: ' + repr(sData))
if cData:
self.serverIO.bufferReceived(cData)
if sData:
self.clientIO.bufferReceived(sData)
if cData or sData:
return True
if (self.serverIO.disconnecting and
not self.serverIO.disconnected):
if self.debug or debug:
print('* C')
self.serverIO.disconnected = True
self.clientIO.disconnecting = True
self.clientIO.reportDisconnect()
return True
if self.clientIO.disconnecting and not self.clientIO.disconnected:
if self.debug or debug:
print('* S')
self.clientIO.disconnected = True
self.serverIO.disconnecting = True
self.serverIO.reportDisconnect()
return True
return False
def connect(serverProtocol, serverTransport,
clientProtocol, clientTransport, debug=False):
"""
Create a new L{IOPump} connecting two protocols.
@param serverProtocol: The protocol to use on the accepting side of the
connection.
@type serverProtocol: L{IProtocol} provider
@param serverTransport: The transport to associate with C{serverProtocol}.
@type serverTransport: L{FakeTransport}
@param clientProtocol: The protocol to use on the initiating side of the
connection.
@type clientProtocol: L{IProtocol} provider
@param clientTransport: The transport to associate with C{clientProtocol}.
@type clientTransport: L{FakeTransport}
@param debug: A flag indicating whether to log information about what the
L{IOPump} is doing.
@type debug: L{bool}
@return: An L{IOPump} which connects C{serverProtocol} and
C{clientProtocol} and delivers bytes between them when it is pumped.
@rtype: L{IOPump}
"""
serverProtocol.makeConnection(serverTransport)
clientProtocol.makeConnection(clientTransport)
pump = IOPump(
clientProtocol, serverProtocol, clientTransport, serverTransport, debug
)
# kick off server greeting, etc
pump.flush()
return pump
def connectedServerAndClient(ServerClass, ClientClass,
clientTransportFactory=makeFakeClient,
serverTransportFactory=makeFakeServer,
debug=False):
"""
Connect a given server and client class to each other.
@param ServerClass: a callable that produces the server-side protocol.
@type ServerClass: 0-argument callable returning L{IProtocol} provider.
@param ClientClass: like C{ServerClass} but for the other side of the
connection.
@type ClientClass: 0-argument callable returning L{IProtocol} provider.
@param clientTransportFactory: a callable that produces the transport which
will be attached to the protocol returned from C{ClientClass}.
@type clientTransportFactory: callable taking (L{IProtocol}) and returning
L{FakeTransport}
@param serverTransportFactory: a callable that produces the transport which
will be attached to the protocol returned from C{ServerClass}.
@type serverTransportFactory: callable taking (L{IProtocol}) and returning
L{FakeTransport}
@param debug: Should this dump an escaped version of all traffic on this
connection to stdout for inspection?
@type debug: L{bool}
@return: the client protocol, the server protocol, and an L{IOPump} which,
when its C{pump} and C{flush} methods are called, will move data
between the created client and server protocol instances.
@rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
"""
c = ClientClass()
s = ServerClass()
cio = clientTransportFactory(c)
sio = serverTransportFactory(s)
return c, s, connect(s, sio, c, cio, debug)

View file

@ -0,0 +1,48 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This is a mock win32process module.
The purpose of this module is mock process creation for the PID test.
CreateProcess(...) will spawn a process, and always return a PID of 42.
"""
import win32process
GetExitCodeProcess = win32process.GetExitCodeProcess
STARTUPINFO = win32process.STARTUPINFO
STARTF_USESTDHANDLES = win32process.STARTF_USESTDHANDLES
def CreateProcess(appName,
cmdline,
procSecurity,
threadSecurity,
inheritHandles,
newEnvironment,
env,
workingDir,
startupInfo):
"""
This function mocks the generated pid aspect of the win32.CreateProcess
function.
- the true win32process.CreateProcess is called
- return values are harvested in a tuple.
- all return values from createProcess are passed back to the calling
function except for the pid, the returned pid is hardcoded to 42
"""
hProcess, hThread, dwPid, dwTid = win32process.CreateProcess(
appName,
cmdline,
procSecurity,
threadSecurity,
inheritHandles,
newEnvironment,
env,
workingDir,
startupInfo)
dwPid = 42
return (hProcess, hThread, dwPid, dwTid)

View file

@ -0,0 +1,15 @@
class A:
def a(self):
return 'a'
try:
object
except NameError:
pass
else:
class B(object, A):
def b(self):
return 'b'
class Inherit(A):
def a(self):
return 'c'

View file

@ -0,0 +1,16 @@
class A:
def a(self):
return 'b'
try:
object
except NameError:
pass
else:
class B(A, object):
def b(self):
return 'c'
class Inherit(A):
def a(self):
return 'd'

View file

@ -0,0 +1,57 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# Don't change the docstring, it's part of the tests
"""
I'm a test drop-in. The plugin system's unit tests use me. No one
else should.
"""
from zope.interface import classProvides
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin, ITestPlugin2
class TestPlugin:
"""
A plugin used solely for testing purposes.
"""
classProvides(ITestPlugin,
IPlugin)
def test1():
pass
test1 = staticmethod(test1)
class AnotherTestPlugin:
"""
Another plugin used solely for testing purposes.
"""
classProvides(ITestPlugin2,
IPlugin)
def test():
pass
test = staticmethod(test)
class ThirdTestPlugin:
"""
Another plugin used solely for testing purposes.
"""
classProvides(ITestPlugin2,
IPlugin)
def test():
pass
test = staticmethod(test)

View file

@ -0,0 +1,23 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test plugin used in L{twisted.test.test_plugin}.
"""
from zope.interface import classProvides
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin
class FourthTestPlugin:
classProvides(ITestPlugin,
IPlugin)
def test1():
pass
test1 = staticmethod(test1)

View file

@ -0,0 +1,35 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test plugin used in L{twisted.test.test_plugin}.
"""
from zope.interface import classProvides
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin
class FourthTestPlugin:
classProvides(ITestPlugin,
IPlugin)
def test1():
pass
test1 = staticmethod(test1)
class FifthTestPlugin:
"""
More documentation: I hate you.
"""
classProvides(ITestPlugin,
IPlugin)
def test1():
pass
test1 = staticmethod(test1)

View file

@ -0,0 +1,5 @@
"""Write to stdout the command line args it received, one per line."""
import sys
for x in sys.argv[1:]:
print x

View file

@ -0,0 +1,11 @@
"""Write back all data it receives."""
import sys
data = sys.stdin.read(1)
while data:
sys.stdout.write(data)
sys.stdout.flush()
data = sys.stdin.read(1)
sys.stderr.write("byebye")
sys.stderr.flush()

View file

@ -0,0 +1,40 @@
"""Write to a handful of file descriptors, to test the childFDs= argument of
reactor.spawnProcess()
"""
import os, sys
debug = 0
if debug: stderr = os.fdopen(2, "w")
if debug: print >>stderr, "this is stderr"
abcd = os.read(0, 4)
if debug: print >>stderr, "read(0):", abcd
if abcd != "abcd":
sys.exit(1)
if debug: print >>stderr, "os.write(1, righto)"
os.write(1, "righto")
efgh = os.read(3, 4)
if debug: print >>stderr, "read(3):", efgh
if efgh != "efgh":
sys.exit(2)
if debug: print >>stderr, "os.close(4)"
os.close(4)
eof = os.read(5, 4)
if debug: print >>stderr, "read(5):", eof
if eof != "":
sys.exit(3)
if debug: print >>stderr, "os.write(1, closed)"
os.write(1, "closed")
if debug: print >>stderr, "sys.exit(0)"
sys.exit(0)

View file

@ -0,0 +1,17 @@
"""Write to a file descriptor and then close it, waiting a few seconds before
quitting. This serves to make sure SIGCHLD is actually being noticed.
"""
import os, sys, time
print "here is some text"
time.sleep(1)
print "goodbye"
os.close(1)
os.close(2)
time.sleep(2)
sys.exit(0)

View file

@ -0,0 +1,12 @@
"""Script used by test_process.TestTwoProcesses"""
# run until stdin is closed, then quit
import sys
while 1:
d = sys.stdin.read()
if len(d) == 0:
sys.exit(0)

View file

@ -0,0 +1,8 @@
import sys, signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
if getattr(signal, "SIGHUP", None) is not None:
signal.signal(signal.SIGHUP, signal.SIG_DFL)
print 'ok, signal us'
sys.stdin.read()
sys.exit(1)

View file

@ -0,0 +1,23 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""Script used by twisted.test.test_process on win32."""
import sys, time, os, msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)
sys.stdout.write("out\n")
sys.stdout.flush()
sys.stderr.write("err\n")
sys.stderr.flush()
data = sys.stdin.read()
sys.stdout.write(data)
sys.stdout.write("\nout\n")
sys.stderr.write("err\n")
sys.stdout.flush()
sys.stderr.flush()

View file

@ -0,0 +1,37 @@
"""Test program for processes."""
import sys, os
test_file_match = "process_test.log.*"
test_file = "process_test.log.%d" % os.getpid()
def main():
f = open(test_file, 'wb')
# stage 1
bytes = sys.stdin.read(4)
f.write("one: %r\n" % bytes)
# stage 2
sys.stdout.write(bytes)
sys.stdout.flush()
os.close(sys.stdout.fileno())
# and a one, and a two, and a...
bytes = sys.stdin.read(4)
f.write("two: %r\n" % bytes)
# stage 3
sys.stderr.write(bytes)
sys.stderr.flush()
os.close(sys.stderr.fileno())
# stage 4
bytes = sys.stdin.read(4)
f.write("three: %r\n" % bytes)
# exit with status code 23
sys.exit(23)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,6 @@
"""Test to make sure we can open /dev/tty"""
f = open("/dev/tty", "r+")
a = f.readline()
f.write(a)
f.close()

View file

@ -0,0 +1,43 @@
"""A process that reads from stdin and out using Twisted."""
### Twisted Preamble
# This makes sure that users don't have to set up their environment
# specially in order to run these programs from bin/.
import sys, os
pos = os.path.abspath(sys.argv[0]).find(os.sep+'Twisted')
if pos != -1:
sys.path.insert(0, os.path.abspath(sys.argv[0])[:pos+8])
sys.path.insert(0, os.curdir)
### end of preamble
from twisted.python import log
from zope.interface import implements
from twisted.internet import interfaces
log.startLogging(sys.stderr)
from twisted.internet import protocol, reactor, stdio
class Echo(protocol.Protocol):
implements(interfaces.IHalfCloseableProtocol)
def connectionMade(self):
print "connection made"
def dataReceived(self, data):
self.transport.write(data)
def readConnectionLost(self):
print "readConnectionLost"
self.transport.loseConnection()
def writeConnectionLost(self):
print "writeConnectionLost"
def connectionLost(self, reason):
print "connectionLost", reason
reactor.stop()
stdio.StandardIO(Echo())
reactor.run()

View file

@ -0,0 +1,690 @@
# -*- test-case-name: twisted.test.test_stringtransport -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Assorted functionality which is commonly useful when writing unit tests.
"""
from __future__ import division, absolute_import
from socket import AF_INET, AF_INET6
from io import BytesIO
from zope.interface import implementer, implementedBy
from zope.interface.verify import verifyClass
from twisted.python import failure
from twisted.python.compat import unicode
from twisted.internet.interfaces import (
ITransport, IConsumer, IPushProducer, IConnector, IReactorTCP, IReactorSSL,
IReactorUNIX, IReactorSocket, IListeningPort, IReactorFDSet
)
from twisted.internet.abstract import isIPv6Address
from twisted.internet.error import UnsupportedAddressFamily
from twisted.protocols import basic
from twisted.internet import protocol, error, address
from twisted.internet.task import Clock
from twisted.internet.address import IPv4Address, UNIXAddress, IPv6Address
class AccumulatingProtocol(protocol.Protocol):
"""
L{AccumulatingProtocol} is an L{IProtocol} implementation which collects
the data delivered to it and can fire a Deferred when it is connected or
disconnected.
@ivar made: A flag indicating whether C{connectionMade} has been called.
@ivar data: Bytes giving all the data passed to C{dataReceived}.
@ivar closed: A flag indicated whether C{connectionLost} has been called.
@ivar closedReason: The value of the I{reason} parameter passed to
C{connectionLost}.
@ivar closedDeferred: If set to a L{Deferred}, this will be fired when
C{connectionLost} is called.
"""
made = closed = 0
closedReason = None
closedDeferred = None
data = b""
factory = None
def connectionMade(self):
self.made = 1
if (self.factory is not None and
self.factory.protocolConnectionMade is not None):
d = self.factory.protocolConnectionMade
self.factory.protocolConnectionMade = None
d.callback(self)
def dataReceived(self, data):
self.data += data
def connectionLost(self, reason):
self.closed = 1
self.closedReason = reason
if self.closedDeferred is not None:
d, self.closedDeferred = self.closedDeferred, None
d.callback(None)
class LineSendingProtocol(basic.LineReceiver):
lostConn = False
def __init__(self, lines, start = True):
self.lines = lines[:]
self.response = []
self.start = start
def connectionMade(self):
if self.start:
for line in self.lines:
self.sendLine(line)
def lineReceived(self, line):
if not self.start:
for line in self.lines:
self.sendLine(line)
self.lines = []
self.response.append(line)
def connectionLost(self, reason):
self.lostConn = True
class FakeDatagramTransport:
noAddr = object()
def __init__(self):
self.written = []
def write(self, packet, addr=noAddr):
self.written.append((packet, addr))
@implementer(ITransport, IConsumer, IPushProducer)
class StringTransport:
"""
A transport implementation which buffers data in memory and keeps track of
its other state without providing any behavior.
L{StringTransport} has a number of attributes which are not part of any of
the interfaces it claims to implement. These attributes are provided for
testing purposes. Implementation code should not use any of these
attributes; they are not provided by other transports.
@ivar disconnecting: A C{bool} which is C{False} until L{loseConnection} is
called, then C{True}.
@ivar producer: If a producer is currently registered, C{producer} is a
reference to it. Otherwise, C{None}.
@ivar streaming: If a producer is currently registered, C{streaming} refers
to the value of the second parameter passed to C{registerProducer}.
@ivar hostAddr: C{None} or an object which will be returned as the host
address of this transport. If C{None}, a nasty tuple will be returned
instead.
@ivar peerAddr: C{None} or an object which will be returned as the peer
address of this transport. If C{None}, a nasty tuple will be returned
instead.
@ivar producerState: The state of this L{StringTransport} in its capacity
as an L{IPushProducer}. One of C{'producing'}, C{'paused'}, or
C{'stopped'}.
@ivar io: A L{BytesIO} which holds the data which has been written to this
transport since the last call to L{clear}. Use L{value} instead of
accessing this directly.
"""
disconnecting = False
producer = None
streaming = None
hostAddr = None
peerAddr = None
producerState = 'producing'
def __init__(self, hostAddress=None, peerAddress=None):
self.clear()
if hostAddress is not None:
self.hostAddr = hostAddress
if peerAddress is not None:
self.peerAddr = peerAddress
self.connected = True
def clear(self):
"""
Discard all data written to this transport so far.
This is not a transport method. It is intended for tests. Do not use
it in implementation code.
"""
self.io = BytesIO()
def value(self):
"""
Retrieve all data which has been buffered by this transport.
This is not a transport method. It is intended for tests. Do not use
it in implementation code.
@return: A C{bytes} giving all data written to this transport since the
last call to L{clear}.
@rtype: C{bytes}
"""
return self.io.getvalue()
# ITransport
def write(self, data):
if isinstance(data, unicode): # no, really, I mean it
raise TypeError("Data must not be unicode")
self.io.write(data)
def writeSequence(self, data):
self.io.write(b''.join(data))
def loseConnection(self):
"""
Close the connection. Does nothing besides toggle the C{disconnecting}
instance variable to C{True}.
"""
self.disconnecting = True
def getPeer(self):
if self.peerAddr is None:
return address.IPv4Address('TCP', '192.168.1.1', 54321)
return self.peerAddr
def getHost(self):
if self.hostAddr is None:
return address.IPv4Address('TCP', '10.0.0.1', 12345)
return self.hostAddr
# IConsumer
def registerProducer(self, producer, streaming):
if self.producer is not None:
raise RuntimeError("Cannot register two producers")
self.producer = producer
self.streaming = streaming
def unregisterProducer(self):
if self.producer is None:
raise RuntimeError(
"Cannot unregister a producer unless one is registered")
self.producer = None
self.streaming = None
# IPushProducer
def _checkState(self):
if self.disconnecting:
raise RuntimeError(
"Cannot resume producing after loseConnection")
if self.producerState == 'stopped':
raise RuntimeError("Cannot resume a stopped producer")
def pauseProducing(self):
self._checkState()
self.producerState = 'paused'
def stopProducing(self):
self.producerState = 'stopped'
def resumeProducing(self):
self._checkState()
self.producerState = 'producing'
class StringTransportWithDisconnection(StringTransport):
"""
A L{StringTransport} which can be disconnected.
"""
def loseConnection(self):
if self.connected:
self.connected = False
self.protocol.connectionLost(
failure.Failure(error.ConnectionDone("Bye.")))
class StringIOWithoutClosing(BytesIO):
"""
A BytesIO that can't be closed.
"""
def close(self):
"""
Do nothing.
"""
@implementer(IListeningPort)
class _FakePort(object):
"""
A fake L{IListeningPort} to be used in tests.
@ivar _hostAddress: The L{IAddress} this L{IListeningPort} is pretending
to be listening on.
"""
def __init__(self, hostAddress):
"""
@param hostAddress: An L{IAddress} this L{IListeningPort} should
pretend to be listening on.
"""
self._hostAddress = hostAddress
def startListening(self):
"""
Fake L{IListeningPort.startListening} that doesn't do anything.
"""
def stopListening(self):
"""
Fake L{IListeningPort.stopListening} that doesn't do anything.
"""
def getHost(self):
"""
Fake L{IListeningPort.getHost} that returns our L{IAddress}.
"""
return self._hostAddress
@implementer(IConnector)
class _FakeConnector(object):
"""
A fake L{IConnector} that allows us to inspect if it has been told to stop
connecting.
@ivar stoppedConnecting: has this connector's
L{FakeConnector.stopConnecting} method been invoked yet?
@ivar _address: An L{IAddress} provider that represents our destination.
"""
_disconnected = False
stoppedConnecting = False
def __init__(self, address):
"""
@param address: An L{IAddress} provider that represents this
connector's destination.
"""
self._address = address
def stopConnecting(self):
"""
Implement L{IConnector.stopConnecting} and set
L{FakeConnector.stoppedConnecting} to C{True}
"""
self.stoppedConnecting = True
def disconnect(self):
"""
Implement L{IConnector.disconnect} as a no-op.
"""
self._disconnected = True
def connect(self):
"""
Implement L{IConnector.connect} as a no-op.
"""
def getDestination(self):
"""
Implement L{IConnector.getDestination} to return the C{address} passed
to C{__init__}.
"""
return self._address
@implementer(
IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket, IReactorFDSet
)
class MemoryReactor(object):
"""
A fake reactor to be used in tests. This reactor doesn't actually do
much that's useful yet. It accepts TCP connection setup attempts, but
they will never succeed.
@ivar tcpClients: a list that keeps track of connection attempts (ie, calls
to C{connectTCP}).
@type tcpClients: C{list}
@ivar tcpServers: a list that keeps track of server listen attempts (ie, calls
to C{listenTCP}).
@type tcpServers: C{list}
@ivar sslClients: a list that keeps track of connection attempts (ie,
calls to C{connectSSL}).
@type sslClients: C{list}
@ivar sslServers: a list that keeps track of server listen attempts (ie,
calls to C{listenSSL}).
@type sslServers: C{list}
@ivar unixClients: a list that keeps track of connection attempts (ie,
calls to C{connectUNIX}).
@type unixClients: C{list}
@ivar unixServers: a list that keeps track of server listen attempts (ie,
calls to C{listenUNIX}).
@type unixServers: C{list}
@ivar adoptedPorts: a list that keeps track of server listen attempts (ie,
calls to C{adoptStreamPort}).
@ivar adoptedStreamConnections: a list that keeps track of stream-oriented
connections added using C{adoptStreamConnection}.
"""
def __init__(self):
"""
Initialize the tracking lists.
"""
self.tcpClients = []
self.tcpServers = []
self.sslClients = []
self.sslServers = []
self.unixClients = []
self.unixServers = []
self.adoptedPorts = []
self.adoptedStreamConnections = []
self.connectors = []
self.readers = set()
self.writers = set()
def adoptStreamPort(self, fileno, addressFamily, factory):
"""
Fake L{IReactorSocket.adoptStreamPort}, that logs the call and returns
an L{IListeningPort}.
"""
if addressFamily == AF_INET:
addr = IPv4Address('TCP', '0.0.0.0', 1234)
elif addressFamily == AF_INET6:
addr = IPv6Address('TCP', '::', 1234)
else:
raise UnsupportedAddressFamily()
self.adoptedPorts.append((fileno, addressFamily, factory))
return _FakePort(addr)
def adoptStreamConnection(self, fileDescriptor, addressFamily, factory):
"""
Record the given stream connection in C{adoptedStreamConnections}.
@see: L{twisted.internet.interfaces.IReactorSocket.adoptStreamConnection}
"""
self.adoptedStreamConnections.append((
fileDescriptor, addressFamily, factory))
def adoptDatagramPort(self, fileno, addressFamily, protocol,
maxPacketSize=8192):
"""
Fake L{IReactorSocket.adoptDatagramPort}, that logs the call and returns
a fake L{IListeningPort}.
@see: L{twisted.internet.interfaces.IReactorSocket.adoptDatagramPort}
"""
if addressFamily == AF_INET:
addr = IPv4Address('UDP', '0.0.0.0', 1234)
elif addressFamily == AF_INET6:
addr = IPv6Address('UDP', '::', 1234)
else:
raise UnsupportedAddressFamily()
self.adoptedPorts.append(
(fileno, addressFamily, protocol, maxPacketSize))
return _FakePort(addr)
def listenTCP(self, port, factory, backlog=50, interface=''):
"""
Fake L{reactor.listenTCP}, that logs the call and returns an
L{IListeningPort}.
"""
self.tcpServers.append((port, factory, backlog, interface))
if isIPv6Address(interface):
address = IPv6Address('TCP', interface, port)
else:
address = IPv4Address('TCP', '0.0.0.0', port)
return _FakePort(address)
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
"""
Fake L{reactor.connectTCP}, that logs the call and returns an
L{IConnector}.
"""
self.tcpClients.append((host, port, factory, timeout, bindAddress))
if isIPv6Address(host):
conn = _FakeConnector(IPv6Address('TCP', host, port))
else:
conn = _FakeConnector(IPv4Address('TCP', host, port))
factory.startedConnecting(conn)
self.connectors.append(conn)
return conn
def listenSSL(self, port, factory, contextFactory,
backlog=50, interface=''):
"""
Fake L{reactor.listenSSL}, that logs the call and returns an
L{IListeningPort}.
"""
self.sslServers.append((port, factory, contextFactory,
backlog, interface))
return _FakePort(IPv4Address('TCP', '0.0.0.0', port))
def connectSSL(self, host, port, factory, contextFactory,
timeout=30, bindAddress=None):
"""
Fake L{reactor.connectSSL}, that logs the call and returns an
L{IConnector}.
"""
self.sslClients.append((host, port, factory, contextFactory,
timeout, bindAddress))
conn = _FakeConnector(IPv4Address('TCP', host, port))
factory.startedConnecting(conn)
self.connectors.append(conn)
return conn
def listenUNIX(self, address, factory,
backlog=50, mode=0o666, wantPID=0):
"""
Fake L{reactor.listenUNIX}, that logs the call and returns an
L{IListeningPort}.
"""
self.unixServers.append((address, factory, backlog, mode, wantPID))
return _FakePort(UNIXAddress(address))
def connectUNIX(self, address, factory, timeout=30, checkPID=0):
"""
Fake L{reactor.connectUNIX}, that logs the call and returns an
L{IConnector}.
"""
self.unixClients.append((address, factory, timeout, checkPID))
conn = _FakeConnector(UNIXAddress(address))
factory.startedConnecting(conn)
self.connectors.append(conn)
return conn
def addReader(self, reader):
"""
Fake L{IReactorFDSet.addReader} which adds the reader to a local set.
"""
self.readers.add(reader)
def removeReader(self, reader):
"""
Fake L{IReactorFDSet.removeReader} which removes the reader from a
local set.
"""
self.readers.discard(reader)
def addWriter(self, writer):
"""
Fake L{IReactorFDSet.addWriter} which adds the writer to a local set.
"""
self.writers.add(writer)
def removeWriter(self, writer):
"""
Fake L{IReactorFDSet.removeWriter} which removes the writer from a
local set.
"""
self.writers.discard(writer)
def getReaders(self):
"""
Fake L{IReactorFDSet.getReaders} which returns a list of readers from
the local set.
"""
return list(self.readers)
def getWriters(self):
"""
Fake L{IReactorFDSet.getWriters} which returns a list of writers from
the local set.
"""
return list(self.writers)
def removeAll(self):
"""
Fake L{IReactorFDSet.removeAll} which removed all readers and writers
from the local sets.
"""
self.readers.clear()
self.writers.clear()
for iface in implementedBy(MemoryReactor):
verifyClass(iface, MemoryReactor)
class MemoryReactorClock(MemoryReactor, Clock):
def __init__(self):
MemoryReactor.__init__(self)
Clock.__init__(self)
@implementer(IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket)
class RaisingMemoryReactor(object):
"""
A fake reactor to be used in tests. It accepts TCP connection setup
attempts, but they will fail.
@ivar _listenException: An instance of an L{Exception}
@ivar _connectException: An instance of an L{Exception}
"""
def __init__(self, listenException=None, connectException=None):
"""
@param listenException: An instance of an L{Exception} to raise when any
C{listen} method is called.
@param connectException: An instance of an L{Exception} to raise when
any C{connect} method is called.
"""
self._listenException = listenException
self._connectException = connectException
def adoptStreamPort(self, fileno, addressFamily, factory):
"""
Fake L{IReactorSocket.adoptStreamPort}, that raises
L{self._listenException}.
"""
raise self._listenException
def listenTCP(self, port, factory, backlog=50, interface=''):
"""
Fake L{reactor.listenTCP}, that raises L{self._listenException}.
"""
raise self._listenException
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
"""
Fake L{reactor.connectTCP}, that raises L{self._connectException}.
"""
raise self._connectException
def listenSSL(self, port, factory, contextFactory,
backlog=50, interface=''):
"""
Fake L{reactor.listenSSL}, that raises L{self._listenException}.
"""
raise self._listenException
def connectSSL(self, host, port, factory, contextFactory,
timeout=30, bindAddress=None):
"""
Fake L{reactor.connectSSL}, that raises L{self._connectException}.
"""
raise self._connectException
def listenUNIX(self, address, factory,
backlog=50, mode=0o666, wantPID=0):
"""
Fake L{reactor.listenUNIX}, that raises L{self._listenException}.
"""
raise self._listenException
def connectUNIX(self, address, factory, timeout=30, checkPID=0):
"""
Fake L{reactor.connectUNIX}, that raises L{self._connectException}.
"""
raise self._connectException

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,21 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A trivial extension that just raises an exception.
See L{twisted.test.test_failure.test_failureConstructionWithMungedStackSucceeds}.
"""
class RaiserException(Exception):
"""
A speficic exception only used to be identified in tests.
"""
def raiseException():
"""
Raise L{RaiserException}.
"""
raise RaiserException("This function is intentionally broken")

View file

@ -0,0 +1,4 @@
# Helper for a test_reflect test
import idonotexist

View file

@ -0,0 +1,4 @@
# Helper for a test_reflect test
raise ValueError("Stuff is broken and things")

View file

@ -0,0 +1,4 @@
# Helper module for a test_reflect test
1//0

View file

@ -0,0 +1,36 @@
-----BEGIN CERTIFICATE-----
MIIDBjCCAm+gAwIBAgIBATANBgkqhkiG9w0BAQQFADB7MQswCQYDVQQGEwJTRzER
MA8GA1UEChMITTJDcnlwdG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQD
ExtNMkNyeXB0byBDZXJ0aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5n
cHNAcG9zdDEuY29tMB4XDTAwMDkxMDA5NTEzMFoXDTAyMDkxMDA5NTEzMFowUzEL
MAkGA1UEBhMCU0cxETAPBgNVBAoTCE0yQ3J5cHRvMRIwEAYDVQQDEwlsb2NhbGhv
c3QxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tMFwwDQYJKoZIhvcNAQEB
BQADSwAwSAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAaOCAQQwggEAMAkGA1UdEwQC
MAAwLAYJYIZIAYb4QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRl
MB0GA1UdDgQWBBTPhIKSvnsmYsBVNWjj0m3M2z0qVTCBpQYDVR0jBIGdMIGagBT7
hyNp65w6kxXlxb8pUU/+7Sg4AaF/pH0wezELMAkGA1UEBhMCU0cxETAPBgNVBAoT
CE0yQ3J5cHRvMRQwEgYDVQQLEwtNMkNyeXB0byBDQTEkMCIGA1UEAxMbTTJDcnlw
dG8gQ2VydGlmaWNhdGUgTWFzdGVyMR0wGwYJKoZIhvcNAQkBFg5uZ3BzQHBvc3Qx
LmNvbYIBADANBgkqhkiG9w0BAQQFAAOBgQA7/CqT6PoHycTdhEStWNZde7M/2Yc6
BoJuVwnW8YxGO8Sn6UJ4FeffZNcYZddSDKosw8LtPOeWoK3JINjAk5jiPQ2cww++
7QGG/g5NDjxFZNDJP1dGiLAxPW6JXwov4v0FmdzfLOZ01jDcgQQZqEpYlgpuI5JE
WUQ9Ho4EzbYCOQ==
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIIBPAIBAAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAQJBAIqm/bz4NA1H++Vx5Ewx
OcKp3w19QSaZAwlGRtsUxrP7436QjnREM3Bm8ygU11BjkPVmtrKm6AayQfCHqJoT
ZIECIQDW0BoMoL0HOYM/mrTLhaykYAVqgIeJsPjvkEhTFXWBuQIhAM3deFAvWNu4
nklUQ37XsCT2c9tmNt1LAT+slG2JOTTRAiAuXDtC/m3NYVwyHfFm+zKHRzHkClk2
HjubeEgjpj32AQIhAJqMGTaZVOwevTXvvHwNEH+vRWsAYU/gbx+OQB+7VOcBAiEA
oolb6NMg/R3enNPvS1O4UU1H8wpaF77L4yiSWlE0p4w=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE REQUEST-----
MIIBDTCBuAIBADBTMQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQ
BgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20w
XDANBgkqhkiG9w0BAQEFAANLADBIAkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5Ad
NgLzJ1/MfsQQJ7hHVeHmTAjM664V+fXvwUGJLziCeBo1ysWLRnl8CQIDAQABoAAw
DQYJKoZIhvcNAQEEBQADQQA7uqbrNTjVWpF6By5ZNPvhZ4YdFgkeXFVWi5ao/TaP
Vq4BG021fJ9nlHRtr4rotpgHDX1rr+iWeHKsx4+5DRSy
-----END CERTIFICATE REQUEST-----

View file

@ -0,0 +1,37 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helper classes for twisted.test.test_ssl.
They are in a separate module so they will not prevent test_ssl importing if
pyOpenSSL is unavailable.
"""
from __future__ import division, absolute_import
from twisted.python.compat import nativeString
from twisted.internet import ssl
from twisted.python.filepath import FilePath
from OpenSSL import SSL
certPath = nativeString(FilePath(__file__.encode("utf-8")
).sibling(b"server.pem").path)
class ClientTLSContext(ssl.ClientContextFactory):
isClient = 1
def getContext(self):
return SSL.Context(SSL.TLSv1_METHOD)
class ServerTLSContext:
isClient = 0
def __init__(self, filename=certPath):
self.filename = filename
def getContext(self):
ctx = SSL.Context(SSL.TLSv1_METHOD)
ctx.use_certificate_file(self.filename)
ctx.use_privatekey_file(self.filename)
return ctx

View file

@ -0,0 +1,39 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_consumer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_consumer} to test
that process transports implement IConsumer properly.
"""
import sys, _preamble
from twisted.python import log, reflect
from twisted.internet import stdio, protocol
from twisted.protocols import basic
def failed(err):
log.startLogging(sys.stderr)
log.err(err)
class ConsumerChild(protocol.Protocol):
def __init__(self, junkPath):
self.junkPath = junkPath
def connectionMade(self):
d = basic.FileSender().beginFileTransfer(file(self.junkPath), self.transport)
d.addErrback(failed)
d.addCallback(lambda ign: self.transport.loseConnection())
def connectionLost(self, reason):
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
from twisted.internet import reactor
stdio.StandardIO(ConsumerChild(sys.argv[2]))
reactor.run()

View file

@ -0,0 +1,66 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_readConnectionLost -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_readConnectionLost}
to test that IHalfCloseableProtocol.readConnectionLost works for process
transports.
"""
import sys, _preamble
from zope.interface import implements
from twisted.internet.interfaces import IHalfCloseableProtocol
from twisted.internet import stdio, protocol
from twisted.python import reflect, log
class HalfCloseProtocol(protocol.Protocol):
"""
A protocol to hook up to stdio and observe its transport being
half-closed. If all goes as expected, C{exitCode} will be set to C{0};
otherwise it will be set to C{1} to indicate failure.
"""
implements(IHalfCloseableProtocol)
exitCode = None
def connectionMade(self):
"""
Signal the parent process that we're ready.
"""
self.transport.write("x")
def readConnectionLost(self):
"""
This is the desired event. Once it has happened, stop the reactor so
the process will exit.
"""
self.exitCode = 0
reactor.stop()
def connectionLost(self, reason):
"""
This may only be invoked after C{readConnectionLost}. If it happens
otherwise, mark it as an error and shut down.
"""
if self.exitCode is None:
self.exitCode = 1
log.err(reason, "Unexpected call to connectionLost")
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
log.startLogging(file(sys.argv[2], 'w'))
from twisted.internet import reactor
protocol = HalfCloseProtocol()
stdio.StandardIO(protocol)
reactor.run()
sys.exit(protocol.exitCode)

View file

@ -0,0 +1,32 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_hostAndPeer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_hostAndPeer} to test
that ITransport.getHost() and ITransport.getPeer() work for process transports.
"""
import sys, _preamble
from twisted.internet import stdio, protocol
from twisted.python import reflect
class HostPeerChild(protocol.Protocol):
def connectionMade(self):
self.transport.write('\n'.join([
str(self.transport.getHost()),
str(self.transport.getPeer())]))
self.transport.loseConnection()
def connectionLost(self, reason):
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
from twisted.internet import reactor
stdio.StandardIO(HostPeerChild())
reactor.run()

View file

@ -0,0 +1,45 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_lastWriteReceived -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_lastWriteReceived}
to test that L{os.write} can be reliably used after
L{twisted.internet.stdio.StandardIO} has finished.
"""
import sys, _preamble
from twisted.internet.protocol import Protocol
from twisted.internet.stdio import StandardIO
from twisted.python.reflect import namedAny
class LastWriteChild(Protocol):
def __init__(self, reactor, magicString):
self.reactor = reactor
self.magicString = magicString
def connectionMade(self):
self.transport.write(self.magicString)
self.transport.loseConnection()
def connectionLost(self, reason):
self.reactor.stop()
def main(reactor, magicString):
p = LastWriteChild(reactor, magicString)
StandardIO(p)
reactor.run()
if __name__ == '__main__':
namedAny(sys.argv[1]).install()
from twisted.internet import reactor
main(reactor, sys.argv[2])

View file

@ -0,0 +1,48 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_loseConnection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_loseConnection} to
test that ITransport.loseConnection() works for process transports.
"""
import sys, _preamble
from twisted.internet.error import ConnectionDone
from twisted.internet import stdio, protocol
from twisted.python import reflect, log
class LoseConnChild(protocol.Protocol):
exitCode = 0
def connectionMade(self):
self.transport.loseConnection()
def connectionLost(self, reason):
"""
Check that C{reason} is a L{Failure} wrapping a L{ConnectionDone}
instance and stop the reactor. If C{reason} is wrong for some reason,
log something about that in C{self.errorLogFile} and make sure the
process exits with a non-zero status.
"""
try:
try:
reason.trap(ConnectionDone)
except:
log.err(None, "Problem with reason passed to connectionLost")
self.exitCode = 1
finally:
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
log.startLogging(file(sys.argv[2], 'w'))
from twisted.internet import reactor
protocol = LoseConnChild()
stdio.StandardIO(protocol)
reactor.run()
sys.exit(protocol.exitCode)

View file

@ -0,0 +1,55 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_producer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_producer} to test
that process transports implement IProducer properly.
"""
import sys, _preamble
from twisted.internet import stdio, protocol
from twisted.python import log, reflect
class ProducerChild(protocol.Protocol):
_paused = False
buf = ''
def connectionLost(self, reason):
log.msg("*****OVER*****")
reactor.callLater(1, reactor.stop)
# reactor.stop()
def dataReceived(self, bytes):
self.buf += bytes
if self._paused:
log.startLogging(sys.stderr)
log.msg("dataReceived while transport paused!")
self.transport.loseConnection()
else:
self.transport.write(bytes)
if self.buf.endswith('\n0\n'):
self.transport.loseConnection()
else:
self.pause()
def pause(self):
self._paused = True
self.transport.pauseProducing()
reactor.callLater(0.01, self.unpause)
def unpause(self):
self._paused = False
self.transport.resumeProducing()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
from twisted.internet import reactor
stdio.StandardIO(ProducerChild())
reactor.run()

View file

@ -0,0 +1,31 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_write -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_write} to test that
ITransport.write() works for process transports.
"""
import sys, _preamble
from twisted.internet import stdio, protocol
from twisted.python import reflect
class WriteChild(protocol.Protocol):
def connectionMade(self):
for ch in 'ok!':
self.transport.write(ch)
self.transport.loseConnection()
def connectionLost(self, reason):
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
from twisted.internet import reactor
stdio.StandardIO(WriteChild())
reactor.run()

View file

@ -0,0 +1,30 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_writeSequence -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_writeSequence} to test that
ITransport.writeSequence() works for process transports.
"""
import sys, _preamble
from twisted.internet import stdio, protocol
from twisted.python import reflect
class WriteSequenceChild(protocol.Protocol):
def connectionMade(self):
self.transport.writeSequence(list('ok!'))
self.transport.loseConnection()
def connectionLost(self, reason):
reactor.stop()
if __name__ == '__main__':
reflect.namedAny(sys.argv[1]).install()
from twisted.internet import reactor
stdio.StandardIO(WriteSequenceChild())
reactor.run()

View file

@ -0,0 +1,85 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for generic file descriptor based reactor support code.
"""
from __future__ import division, absolute_import
from twisted.trial.unittest import TestCase
from twisted.internet.abstract import isIPAddress
class AddressTests(TestCase):
"""
Tests for address-related functionality.
"""
def test_decimalDotted(self):
"""
L{isIPAddress} should return C{True} for any decimal dotted
representation of an IPv4 address.
"""
self.assertTrue(isIPAddress('0.1.2.3'))
self.assertTrue(isIPAddress('252.253.254.255'))
def test_shortDecimalDotted(self):
"""
L{isIPAddress} should return C{False} for a dotted decimal
representation with fewer or more than four octets.
"""
self.assertFalse(isIPAddress('0'))
self.assertFalse(isIPAddress('0.1'))
self.assertFalse(isIPAddress('0.1.2'))
self.assertFalse(isIPAddress('0.1.2.3.4'))
def test_invalidLetters(self):
"""
L{isIPAddress} should return C{False} for any non-decimal dotted
representation including letters.
"""
self.assertFalse(isIPAddress('a.2.3.4'))
self.assertFalse(isIPAddress('1.b.3.4'))
def test_invalidPunctuation(self):
"""
L{isIPAddress} should return C{False} for a string containing
strange punctuation.
"""
self.assertFalse(isIPAddress(','))
self.assertFalse(isIPAddress('1,2'))
self.assertFalse(isIPAddress('1,2,3'))
self.assertFalse(isIPAddress('1.,.3,4'))
def test_emptyString(self):
"""
L{isIPAddress} should return C{False} for the empty string.
"""
self.assertFalse(isIPAddress(''))
def test_invalidNegative(self):
"""
L{isIPAddress} should return C{False} for negative decimal values.
"""
self.assertFalse(isIPAddress('-1'))
self.assertFalse(isIPAddress('1.-2'))
self.assertFalse(isIPAddress('1.2.-3'))
self.assertFalse(isIPAddress('1.2.-3.4'))
def test_invalidPositive(self):
"""
L{isIPAddress} should return C{False} for a string containing
positive decimal values greater than 255.
"""
self.assertFalse(isIPAddress('256.0.0.0'))
self.assertFalse(isIPAddress('0.256.0.0'))
self.assertFalse(isIPAddress('0.0.256.0'))
self.assertFalse(isIPAddress('0.0.0.256'))
self.assertFalse(isIPAddress('256.256.256.256'))

View file

@ -0,0 +1,819 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted.enterprise.adbapi.
"""
from twisted.trial import unittest
import os, stat
import types
from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost
from twisted.enterprise.adbapi import Connection, Transaction
from twisted.internet import reactor, defer, interfaces
from twisted.python.failure import Failure
simple_table_schema = """
CREATE TABLE simple (
x integer
)
"""
class ADBAPITestBase:
"""Test the asynchronous DB-API code."""
openfun_called = {}
if interfaces.IReactorThreads(reactor, None) is None:
skip = "ADB-API requires threads, no way to test without them"
def extraSetUp(self):
"""
Set up the database and create a connection pool pointing at it.
"""
self.startDB()
self.dbpool = self.makePool(cp_openfun=self.openfun)
self.dbpool.start()
def tearDown(self):
d = self.dbpool.runOperation('DROP TABLE simple')
d.addCallback(lambda res: self.dbpool.close())
d.addCallback(lambda res: self.stopDB())
return d
def openfun(self, conn):
self.openfun_called[conn] = True
def checkOpenfunCalled(self, conn=None):
if not conn:
self.failUnless(self.openfun_called)
else:
self.failUnless(self.openfun_called.has_key(conn))
def testPool(self):
d = self.dbpool.runOperation(simple_table_schema)
if self.test_failures:
d.addCallback(self._testPool_1_1)
d.addCallback(self._testPool_1_2)
d.addCallback(self._testPool_1_3)
d.addCallback(self._testPool_1_4)
d.addCallback(lambda res: self.flushLoggedErrors())
d.addCallback(self._testPool_2)
d.addCallback(self._testPool_3)
d.addCallback(self._testPool_4)
d.addCallback(self._testPool_5)
d.addCallback(self._testPool_6)
d.addCallback(self._testPool_7)
d.addCallback(self._testPool_8)
d.addCallback(self._testPool_9)
return d
def _testPool_1_1(self, res):
d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: None)
return d
def _testPool_1_2(self, res):
d = defer.maybeDeferred(self.dbpool.runOperation,
"deletexxx from NOTABLE")
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: None)
return d
def _testPool_1_3(self, res):
d = defer.maybeDeferred(self.dbpool.runInteraction,
self.bad_interaction)
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: None)
return d
def _testPool_1_4(self, res):
d = defer.maybeDeferred(self.dbpool.runWithConnection,
self.bad_withConnection)
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: None)
return d
def _testPool_2(self, res):
# verify simple table is empty
sql = "select count(1) from simple"
d = self.dbpool.runQuery(sql)
def _check(row):
self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
self.checkOpenfunCalled()
d.addCallback(_check)
return d
def _testPool_3(self, res):
sql = "select count(1) from simple"
inserts = []
# add some rows to simple table (runOperation)
for i in range(self.num_iterations):
sql = "insert into simple(x) values(%d)" % i
inserts.append(self.dbpool.runOperation(sql))
d = defer.gatherResults(inserts)
def _select(res):
# make sure they were added (runQuery)
sql = "select x from simple order by x";
d = self.dbpool.runQuery(sql)
return d
d.addCallback(_select)
def _check(rows):
self.failUnless(len(rows) == self.num_iterations,
"Wrong number of rows")
for i in range(self.num_iterations):
self.failUnless(len(rows[i]) == 1, "Wrong size row")
self.failUnless(rows[i][0] == i, "Values not returned.")
d.addCallback(_check)
return d
def _testPool_4(self, res):
# runInteraction
d = self.dbpool.runInteraction(self.interaction)
d.addCallback(lambda res: self.assertEqual(res, "done"))
return d
def _testPool_5(self, res):
# withConnection
d = self.dbpool.runWithConnection(self.withConnection)
d.addCallback(lambda res: self.assertEqual(res, "done"))
return d
def _testPool_6(self, res):
# Test a withConnection cannot be closed
d = self.dbpool.runWithConnection(self.close_withConnection)
return d
def _testPool_7(self, res):
# give the pool a workout
ds = []
for i in range(self.num_iterations):
sql = "select x from simple where x = %d" % i
ds.append(self.dbpool.runQuery(sql))
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
def _check(result):
for i in range(self.num_iterations):
self.failUnless(result[i][1][0][0] == i, "Value not returned")
dlist.addCallback(_check)
return dlist
def _testPool_8(self, res):
# now delete everything
ds = []
for i in range(self.num_iterations):
sql = "delete from simple where x = %d" % i
ds.append(self.dbpool.runOperation(sql))
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
return dlist
def _testPool_9(self, res):
# verify simple table is empty
sql = "select count(1) from simple"
d = self.dbpool.runQuery(sql)
def _check(row):
self.failUnless(int(row[0][0]) == 0,
"Didn't successfully delete table contents")
self.checkConnect()
d.addCallback(_check)
return d
def checkConnect(self):
"""Check the connect/disconnect synchronous calls."""
conn = self.dbpool.connect()
self.checkOpenfunCalled(conn)
curs = conn.cursor()
curs.execute("insert into simple(x) values(1)")
curs.execute("select x from simple")
res = curs.fetchall()
self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), 1)
self.assertEqual(res[0][0], 1)
curs.execute("delete from simple")
curs.execute("select x from simple")
self.assertEqual(len(curs.fetchall()), 0)
curs.close()
self.dbpool.disconnect(conn)
def interaction(self, transaction):
transaction.execute("select x from simple order by x")
for i in range(self.num_iterations):
row = transaction.fetchone()
self.failUnless(len(row) == 1, "Wrong size row")
self.failUnless(row[0] == i, "Value not returned.")
# should test this, but gadfly throws an exception instead
#self.failUnless(transaction.fetchone() is None, "Too many rows")
return "done"
def bad_interaction(self, transaction):
if self.can_rollback:
transaction.execute("insert into simple(x) values(0)")
transaction.execute("select * from NOTABLE")
def withConnection(self, conn):
curs = conn.cursor()
try:
curs.execute("select x from simple order by x")
for i in range(self.num_iterations):
row = curs.fetchone()
self.failUnless(len(row) == 1, "Wrong size row")
self.failUnless(row[0] == i, "Value not returned.")
# should test this, but gadfly throws an exception instead
#self.failUnless(transaction.fetchone() is None, "Too many rows")
finally:
curs.close()
return "done"
def close_withConnection(self, conn):
conn.close()
def bad_withConnection(self, conn):
curs = conn.cursor()
try:
curs.execute("select * from NOTABLE")
finally:
curs.close()
class ReconnectTestBase:
"""Test the asynchronous DB-API code with reconnect."""
if interfaces.IReactorThreads(reactor, None) is None:
skip = "ADB-API requires threads, no way to test without them"
def extraSetUp(self):
"""
Skip the test if C{good_sql} is unavailable. Otherwise, set up the
database, create a connection pool pointed at it, and set up a simple
schema in it.
"""
if self.good_sql is None:
raise unittest.SkipTest('no good sql for reconnect test')
self.startDB()
self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
cp_good_sql=self.good_sql)
self.dbpool.start()
return self.dbpool.runOperation(simple_table_schema)
def tearDown(self):
d = self.dbpool.runOperation('DROP TABLE simple')
d.addCallback(lambda res: self.dbpool.close())
d.addCallback(lambda res: self.stopDB())
return d
def testPool(self):
d = defer.succeed(None)
d.addCallback(self._testPool_1)
d.addCallback(self._testPool_2)
if not self.early_reconnect:
d.addCallback(self._testPool_3)
d.addCallback(self._testPool_4)
d.addCallback(self._testPool_5)
return d
def _testPool_1(self, res):
sql = "select count(1) from simple"
d = self.dbpool.runQuery(sql)
def _check(row):
self.failUnless(int(row[0][0]) == 0, "Table not empty")
d.addCallback(_check)
return d
def _testPool_2(self, res):
# reach in and close the connection manually
self.dbpool.connections.values()[0].close()
def _testPool_3(self, res):
sql = "select count(1) from simple"
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: None)
return d
def _testPool_4(self, res):
sql = "select count(1) from simple"
d = self.dbpool.runQuery(sql)
def _check(row):
self.failUnless(int(row[0][0]) == 0, "Table not empty")
d.addCallback(_check)
return d
def _testPool_5(self, res):
self.flushLoggedErrors()
sql = "select * from NOTABLE" # bad sql
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
d.addCallbacks(lambda res: self.fail('no exception'),
lambda f: self.failIf(f.check(ConnectionLost)))
return d
class DBTestConnector:
"""A class which knows how to test for the presence of
and establish a connection to a relational database.
To enable test cases which use a central, system database,
you must create a database named DB_NAME with a user DB_USER
and password DB_PASS with full access rights to database DB_NAME.
"""
TEST_PREFIX = None # used for creating new test cases
DB_NAME = "twisted_test"
DB_USER = 'twisted_test'
DB_PASS = 'twisted_test'
DB_DIR = None # directory for database storage
nulls_ok = True # nulls supported
trailing_spaces_ok = True # trailing spaces in strings preserved
can_rollback = True # rollback supported
test_failures = True # test bad sql?
escape_slashes = True # escape \ in sql?
good_sql = ConnectionPool.good_sql
early_reconnect = True # cursor() will fail on closed connection
can_clear = True # can try to clear out tables when starting
num_iterations = 50 # number of iterations for test loops
# (lower this for slow db's)
def setUp(self):
self.DB_DIR = self.mktemp()
os.mkdir(self.DB_DIR)
if not self.can_connect():
raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)
return self.extraSetUp()
def can_connect(self):
"""Return true if this database is present on the system
and can be used in a test."""
raise NotImplementedError()
def startDB(self):
"""Take any steps needed to bring database up."""
pass
def stopDB(self):
"""Bring database down, if needed."""
pass
def makePool(self, **newkw):
"""Create a connection pool with additional keyword arguments."""
args, kw = self.getPoolArgs()
kw = kw.copy()
kw.update(newkw)
return ConnectionPool(*args, **kw)
def getPoolArgs(self):
"""Return a tuple (args, kw) of list and keyword arguments
that need to be passed to ConnectionPool to create a connection
to this database."""
raise NotImplementedError()
class GadflyConnector(DBTestConnector):
TEST_PREFIX = 'Gadfly'
nulls_ok = False
can_rollback = False
escape_slashes = False
good_sql = 'select * from simple where 1=0'
num_iterations = 1 # slow
def can_connect(self):
try: import gadfly
except: return False
if not getattr(gadfly, 'connect', None):
gadfly.connect = gadfly.gadfly
return True
def startDB(self):
import gadfly
conn = gadfly.gadfly()
conn.startup(self.DB_NAME, self.DB_DIR)
# gadfly seems to want us to create something to get the db going
cursor = conn.cursor()
cursor.execute("create table x (x integer)")
conn.commit()
conn.close()
def getPoolArgs(self):
args = ('gadfly', self.DB_NAME, self.DB_DIR)
kw = {'cp_max': 1}
return args, kw
class SQLiteConnector(DBTestConnector):
TEST_PREFIX = 'SQLite'
escape_slashes = False
num_iterations = 1 # slow
def can_connect(self):
try: import sqlite
except: return False
return True
def startDB(self):
self.database = os.path.join(self.DB_DIR, self.DB_NAME)
if os.path.exists(self.database):
os.unlink(self.database)
def getPoolArgs(self):
args = ('sqlite',)
kw = {'database': self.database, 'cp_max': 1}
return args, kw
class PyPgSQLConnector(DBTestConnector):
TEST_PREFIX = "PyPgSQL"
def can_connect(self):
try: from pyPgSQL import PgSQL
except: return False
try:
conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
password=self.DB_PASS)
conn.close()
return True
except:
return False
def getPoolArgs(self):
args = ('pyPgSQL.PgSQL',)
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
'password': self.DB_PASS, 'cp_min': 0}
return args, kw
class PsycopgConnector(DBTestConnector):
TEST_PREFIX = 'Psycopg'
def can_connect(self):
try: import psycopg
except: return False
try:
conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
password=self.DB_PASS)
conn.close()
return True
except:
return False
def getPoolArgs(self):
args = ('psycopg',)
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
'password': self.DB_PASS, 'cp_min': 0}
return args, kw
class MySQLConnector(DBTestConnector):
TEST_PREFIX = 'MySQL'
trailing_spaces_ok = False
can_rollback = False
early_reconnect = False
def can_connect(self):
try: import MySQLdb
except: return False
try:
conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
passwd=self.DB_PASS)
conn.close()
return True
except:
return False
def getPoolArgs(self):
args = ('MySQLdb',)
kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
return args, kw
class FirebirdConnector(DBTestConnector):
TEST_PREFIX = 'Firebird'
test_failures = False # failure testing causes problems
escape_slashes = False
good_sql = None # firebird doesn't handle failed sql well
can_clear = False # firebird is not so good
num_iterations = 5 # slow
def can_connect(self):
try: import kinterbasdb
except: return False
try:
self.startDB()
self.stopDB()
return True
except:
return False
def startDB(self):
import kinterbasdb
self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
sql = 'create database "%s" user "%s" password "%s"'
sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS);
conn = kinterbasdb.create_database(sql)
conn.close()
def getPoolArgs(self):
args = ('kinterbasdb',)
kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
'user': self.DB_USER, 'password': self.DB_PASS}
return args, kw
def stopDB(self):
import kinterbasdb
conn = kinterbasdb.connect(database=self.DB_NAME,
host='127.0.0.1', user=self.DB_USER,
password=self.DB_PASS)
conn.drop_database()
def makeSQLTests(base, suffix, globals):
"""
Make a test case for every db connector which can connect.
@param base: Base class for test case. Additional base classes
will be a DBConnector subclass and unittest.TestCase
@param suffix: A suffix used to create test case names. Prefixes
are defined in the DBConnector subclasses.
"""
connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector,
PsycopgConnector, MySQLConnector, FirebirdConnector]
for connclass in connectors:
name = connclass.TEST_PREFIX + suffix
klass = types.ClassType(name, (connclass, base, unittest.TestCase),
base.__dict__)
globals[name] = klass
# GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
# PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())
# GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
# PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
class FakePool(object):
"""
A fake L{ConnectionPool} for tests.
@ivar connectionFactory: factory for making connections returned by the
C{connect} method.
@type connectionFactory: any callable
"""
reconnect = True
noisy = True
def __init__(self, connectionFactory):
self.connectionFactory = connectionFactory
def connect(self):
"""
Return an instance of C{self.connectionFactory}.
"""
return self.connectionFactory()
def disconnect(self, connection):
"""
Do nothing.
"""
class ConnectionTestCase(unittest.TestCase):
"""
Tests for the L{Connection} class.
"""
def test_rollbackErrorLogged(self):
"""
If an error happens during rollback, L{ConnectionLost} is raised but
the original error is logged.
"""
class ConnectionRollbackRaise(object):
def rollback(self):
raise RuntimeError("problem!")
pool = FakePool(ConnectionRollbackRaise)
connection = Connection(pool)
self.assertRaises(ConnectionLost, connection.rollback)
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].value.args[0], "problem!")
class TransactionTestCase(unittest.TestCase):
"""
Tests for the L{Transaction} class.
"""
def test_reopenLogErrorIfReconnect(self):
"""
If the cursor creation raises an error in L{Transaction.reopen}, it
reconnects but log the error occurred.
"""
class ConnectionCursorRaise(object):
count = 0
def reconnect(self):
pass
def cursor(self):
if self.count == 0:
self.count += 1
raise RuntimeError("problem!")
pool = FakePool(None)
transaction = Transaction(pool, ConnectionCursorRaise())
transaction.reopen()
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].value.args[0], "problem!")
class NonThreadPool(object):
def callInThreadWithCallback(self, onResult, f, *a, **kw):
success = True
try:
result = f(*a, **kw)
except Exception, e:
success = False
result = Failure()
onResult(success, result)
class DummyConnectionPool(ConnectionPool):
"""
A testable L{ConnectionPool};
"""
threadpool = NonThreadPool()
def __init__(self):
"""
Don't forward init call.
"""
self.reactor = reactor
class EventReactor(object):
"""
Partial L{IReactorCore} implementation with simple event-related
methods.
@ivar _running: A C{bool} indicating whether the reactor is pretending
to have been started already or not.
@ivar triggers: A C{list} of pending system event triggers.
"""
def __init__(self, running):
self._running = running
self.triggers = []
def callWhenRunning(self, function):
if self._running:
function()
else:
return self.addSystemEventTrigger('after', 'startup', function)
def addSystemEventTrigger(self, phase, event, trigger):
handle = (phase, event, trigger)
self.triggers.append(handle)
return handle
def removeSystemEventTrigger(self, handle):
self.triggers.remove(handle)
class ConnectionPoolTestCase(unittest.TestCase):
"""
Unit tests for L{ConnectionPool}.
"""
def test_runWithConnectionRaiseOriginalError(self):
"""
If rollback fails, L{ConnectionPool.runWithConnection} raises the
original exception and log the error of the rollback.
"""
class ConnectionRollbackRaise(object):
def __init__(self, pool):
pass
def rollback(self):
raise RuntimeError("problem!")
def raisingFunction(connection):
raise ValueError("foo")
pool = DummyConnectionPool()
pool.connectionFactory = ConnectionRollbackRaise
d = pool.runWithConnection(raisingFunction)
d = self.assertFailure(d, ValueError)
def cbFailed(ignored):
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].value.args[0], "problem!")
d.addCallback(cbFailed)
return d
def test_closeLogError(self):
"""
L{ConnectionPool._close} logs exceptions.
"""
class ConnectionCloseRaise(object):
def close(self):
raise RuntimeError("problem!")
pool = DummyConnectionPool()
pool._close(ConnectionCloseRaise())
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].value.args[0], "problem!")
def test_runWithInteractionRaiseOriginalError(self):
"""
If rollback fails, L{ConnectionPool.runInteraction} raises the
original exception and log the error of the rollback.
"""
class ConnectionRollbackRaise(object):
def __init__(self, pool):
pass
def rollback(self):
raise RuntimeError("problem!")
class DummyTransaction(object):
def __init__(self, pool, connection):
pass
def raisingFunction(transaction):
raise ValueError("foo")
pool = DummyConnectionPool()
pool.connectionFactory = ConnectionRollbackRaise
pool.transactionFactory = DummyTransaction
d = pool.runInteraction(raisingFunction)
d = self.assertFailure(d, ValueError)
def cbFailed(ignored):
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].value.args[0], "problem!")
d.addCallback(cbFailed)
return d
def test_unstartedClose(self):
"""
If L{ConnectionPool.close} is called without L{ConnectionPool.start}
having been called, the pool's startup event is cancelled.
"""
reactor = EventReactor(False)
pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
# There should be a startup trigger waiting.
self.assertEqual(reactor.triggers, [('after', 'startup', pool._start)])
pool.close()
# But not anymore.
self.assertFalse(reactor.triggers)
def test_startedClose(self):
"""
If L{ConnectionPool.close} is called after it has been started, but
not by its shutdown trigger, the shutdown trigger is cancelled.
"""
reactor = EventReactor(True)
pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
# There should be a shutdown trigger waiting.
self.assertEqual(reactor.triggers, [('during', 'shutdown', pool.finalClose)])
pool.close()
# But not anymore.
self.assertFalse(reactor.triggers)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,896 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application} and its interaction with
L{twisted.persisted.sob}.
"""
import copy, os, pickle
from StringIO import StringIO
from twisted.trial import unittest
from twisted.application import service, internet, app
from twisted.persisted import sob
from twisted.python import usage
from twisted.internet import interfaces, defer
from twisted.protocols import wire, basic
from twisted.internet import protocol, reactor
from twisted.application import reactors
from twisted.test.proto_helpers import MemoryReactor
from twisted.python.test.modules_helpers import TwistedModulesMixin
class Dummy:
processName=None
class TestService(unittest.TestCase):
def testName(self):
s = service.Service()
s.setName("hello")
self.assertEqual(s.name, "hello")
def testParent(self):
s = service.Service()
p = service.MultiService()
s.setServiceParent(p)
self.assertEqual(list(p), [s])
self.assertEqual(s.parent, p)
def testApplicationAsParent(self):
s = service.Service()
p = service.Application("")
s.setServiceParent(p)
self.assertEqual(list(service.IServiceCollection(p)), [s])
self.assertEqual(s.parent, service.IServiceCollection(p))
def testNamedChild(self):
s = service.Service()
p = service.MultiService()
s.setName("hello")
s.setServiceParent(p)
self.assertEqual(list(p), [s])
self.assertEqual(s.parent, p)
self.assertEqual(p.getServiceNamed("hello"), s)
def testDoublyNamedChild(self):
s = service.Service()
p = service.MultiService()
s.setName("hello")
s.setServiceParent(p)
self.failUnlessRaises(RuntimeError, s.setName, "lala")
def testDuplicateNamedChild(self):
s = service.Service()
p = service.MultiService()
s.setName("hello")
s.setServiceParent(p)
s = service.Service()
s.setName("hello")
self.failUnlessRaises(RuntimeError, s.setServiceParent, p)
def testDisowning(self):
s = service.Service()
p = service.MultiService()
s.setServiceParent(p)
self.assertEqual(list(p), [s])
self.assertEqual(s.parent, p)
s.disownServiceParent()
self.assertEqual(list(p), [])
self.assertEqual(s.parent, None)
def testRunning(self):
s = service.Service()
self.assert_(not s.running)
s.startService()
self.assert_(s.running)
s.stopService()
self.assert_(not s.running)
def testRunningChildren1(self):
s = service.Service()
p = service.MultiService()
s.setServiceParent(p)
self.assert_(not s.running)
self.assert_(not p.running)
p.startService()
self.assert_(s.running)
self.assert_(p.running)
p.stopService()
self.assert_(not s.running)
self.assert_(not p.running)
def testRunningChildren2(self):
s = service.Service()
def checkRunning():
self.assert_(s.running)
t = service.Service()
t.stopService = checkRunning
t.startService = checkRunning
p = service.MultiService()
s.setServiceParent(p)
t.setServiceParent(p)
p.startService()
p.stopService()
def testAddingIntoRunning(self):
p = service.MultiService()
p.startService()
s = service.Service()
self.assert_(not s.running)
s.setServiceParent(p)
self.assert_(s.running)
s.disownServiceParent()
self.assert_(not s.running)
def testPrivileged(self):
s = service.Service()
def pss():
s.privilegedStarted = 1
s.privilegedStartService = pss
s1 = service.Service()
p = service.MultiService()
s.setServiceParent(p)
s1.setServiceParent(p)
p.privilegedStartService()
self.assert_(s.privilegedStarted)
def testCopying(self):
s = service.Service()
s.startService()
s1 = copy.copy(s)
self.assert_(not s1.running)
self.assert_(s.running)
if hasattr(os, "getuid"):
curuid = os.getuid()
curgid = os.getgid()
else:
curuid = curgid = 0
class TestProcess(unittest.TestCase):
def testID(self):
p = service.Process(5, 6)
self.assertEqual(p.uid, 5)
self.assertEqual(p.gid, 6)
def testDefaults(self):
p = service.Process(5)
self.assertEqual(p.uid, 5)
self.assertEqual(p.gid, None)
p = service.Process(gid=5)
self.assertEqual(p.uid, None)
self.assertEqual(p.gid, 5)
p = service.Process()
self.assertEqual(p.uid, None)
self.assertEqual(p.gid, None)
def testProcessName(self):
p = service.Process()
self.assertEqual(p.processName, None)
p.processName = 'hello'
self.assertEqual(p.processName, 'hello')
class TestInterfaces(unittest.TestCase):
def testService(self):
self.assert_(service.IService.providedBy(service.Service()))
def testMultiService(self):
self.assert_(service.IService.providedBy(service.MultiService()))
self.assert_(service.IServiceCollection.providedBy(service.MultiService()))
def testProcess(self):
self.assert_(service.IProcess.providedBy(service.Process()))
class TestApplication(unittest.TestCase):
def testConstructor(self):
service.Application("hello")
service.Application("hello", 5)
service.Application("hello", 5, 6)
def testProcessComponent(self):
a = service.Application("hello")
self.assertEqual(service.IProcess(a).uid, None)
self.assertEqual(service.IProcess(a).gid, None)
a = service.Application("hello", 5)
self.assertEqual(service.IProcess(a).uid, 5)
self.assertEqual(service.IProcess(a).gid, None)
a = service.Application("hello", 5, 6)
self.assertEqual(service.IProcess(a).uid, 5)
self.assertEqual(service.IProcess(a).gid, 6)
def testServiceComponent(self):
a = service.Application("hello")
self.assert_(service.IService(a) is service.IServiceCollection(a))
self.assertEqual(service.IService(a).name, "hello")
self.assertEqual(service.IService(a).parent, None)
def testPersistableComponent(self):
a = service.Application("hello")
p = sob.IPersistable(a)
self.assertEqual(p.style, 'pickle')
self.assertEqual(p.name, 'hello')
self.assert_(p.original is a)
class TestLoading(unittest.TestCase):
def test_simpleStoreAndLoad(self):
a = service.Application("hello")
p = sob.IPersistable(a)
for style in 'source pickle'.split():
p.setStyle(style)
p.save()
a1 = service.loadApplication("hello.ta"+style[0], style)
self.assertEqual(service.IService(a1).name, "hello")
f = open("hello.tac", 'w')
f.writelines([
"from twisted.application import service\n",
"application = service.Application('hello')\n",
])
f.close()
a1 = service.loadApplication("hello.tac", 'python')
self.assertEqual(service.IService(a1).name, "hello")
class TestAppSupport(unittest.TestCase):
def testPassphrase(self):
self.assertEqual(app.getPassphrase(0), None)
def testLoadApplication(self):
"""
Test loading an application file in different dump format.
"""
a = service.Application("hello")
baseconfig = {'file': None, 'source': None, 'python':None}
for style in 'source pickle'.split():
config = baseconfig.copy()
config[{'pickle': 'file'}.get(style, style)] = 'helloapplication'
sob.IPersistable(a).setStyle(style)
sob.IPersistable(a).save(filename='helloapplication')
a1 = app.getApplication(config, None)
self.assertEqual(service.IService(a1).name, "hello")
config = baseconfig.copy()
config['python'] = 'helloapplication'
f = open("helloapplication", 'w')
f.writelines([
"from twisted.application import service\n",
"application = service.Application('hello')\n",
])
f.close()
a1 = app.getApplication(config, None)
self.assertEqual(service.IService(a1).name, "hello")
def test_convertStyle(self):
appl = service.Application("lala")
for instyle in 'source pickle'.split():
for outstyle in 'source pickle'.split():
sob.IPersistable(appl).setStyle(instyle)
sob.IPersistable(appl).save(filename="converttest")
app.convertStyle("converttest", instyle, None,
"converttest.out", outstyle, 0)
appl2 = service.loadApplication("converttest.out", outstyle)
self.assertEqual(service.IService(appl2).name, "lala")
def test_startApplication(self):
appl = service.Application("lala")
app.startApplication(appl, 0)
self.assert_(service.IService(appl).running)
class Foo(basic.LineReceiver):
def connectionMade(self):
self.transport.write('lalala\r\n')
def lineReceived(self, line):
self.factory.line = line
self.transport.loseConnection()
def connectionLost(self, reason):
self.factory.d.callback(self.factory.line)
class DummyApp:
processName = None
def addService(self, service):
self.services[service.name] = service
def removeService(self, service):
del self.services[service.name]
class TimerTarget:
def __init__(self):
self.l = []
def append(self, what):
self.l.append(what)
class TestEcho(wire.Echo):
def connectionLost(self, reason):
self.d.callback(True)
class TestInternet2(unittest.TestCase):
def testTCP(self):
s = service.MultiService()
s.startService()
factory = protocol.ServerFactory()
factory.protocol = TestEcho
TestEcho.d = defer.Deferred()
t = internet.TCPServer(0, factory)
t.setServiceParent(s)
num = t._port.getHost().port
factory = protocol.ClientFactory()
factory.d = defer.Deferred()
factory.protocol = Foo
factory.line = None
internet.TCPClient('127.0.0.1', num, factory).setServiceParent(s)
factory.d.addCallback(self.assertEqual, 'lalala')
factory.d.addCallback(lambda x : s.stopService())
factory.d.addCallback(lambda x : TestEcho.d)
return factory.d
def test_UDP(self):
"""
Test L{internet.UDPServer} with a random port: starting the service
should give it valid port, and stopService should free it so that we
can start a server on the same port again.
"""
if not interfaces.IReactorUDP(reactor, None):
raise unittest.SkipTest("This reactor does not support UDP sockets")
p = protocol.DatagramProtocol()
t = internet.UDPServer(0, p)
t.startService()
num = t._port.getHost().port
self.assertNotEquals(num, 0)
def onStop(ignored):
t = internet.UDPServer(num, p)
t.startService()
return t.stopService()
return defer.maybeDeferred(t.stopService).addCallback(onStop)
def test_deprecatedUDPClient(self):
"""
L{internet.UDPClient} is deprecated since Twisted-13.1.
"""
internet.UDPClient
warningsShown = self.flushWarnings([self.test_deprecatedUDPClient])
self.assertEqual(1, len(warningsShown))
self.assertEqual(
"twisted.application.internet.UDPClient was deprecated in "
"Twisted 13.1.0: It relies upon IReactorUDP.connectUDP "
"which was removed in Twisted 10. "
"Use twisted.application.internet.UDPServer instead.",
warningsShown[0]['message'])
def testPrivileged(self):
factory = protocol.ServerFactory()
factory.protocol = TestEcho
TestEcho.d = defer.Deferred()
t = internet.TCPServer(0, factory)
t.privileged = 1
t.privilegedStartService()
num = t._port.getHost().port
factory = protocol.ClientFactory()
factory.d = defer.Deferred()
factory.protocol = Foo
factory.line = None
c = internet.TCPClient('127.0.0.1', num, factory)
c.startService()
factory.d.addCallback(self.assertEqual, 'lalala')
factory.d.addCallback(lambda x : c.stopService())
factory.d.addCallback(lambda x : t.stopService())
factory.d.addCallback(lambda x : TestEcho.d)
return factory.d
def testConnectionGettingRefused(self):
factory = protocol.ServerFactory()
factory.protocol = wire.Echo
t = internet.TCPServer(0, factory)
t.startService()
num = t._port.getHost().port
t.stopService()
d = defer.Deferred()
factory = protocol.ClientFactory()
factory.clientConnectionFailed = lambda *args: d.callback(None)
c = internet.TCPClient('127.0.0.1', num, factory)
c.startService()
return d
def testUNIX(self):
# FIXME: This test is far too dense. It needs comments.
# -- spiv, 2004-11-07
if not interfaces.IReactorUNIX(reactor, None):
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
s = service.MultiService()
s.startService()
factory = protocol.ServerFactory()
factory.protocol = TestEcho
TestEcho.d = defer.Deferred()
t = internet.UNIXServer('echo.skt', factory)
t.setServiceParent(s)
factory = protocol.ClientFactory()
factory.protocol = Foo
factory.d = defer.Deferred()
factory.line = None
internet.UNIXClient('echo.skt', factory).setServiceParent(s)
factory.d.addCallback(self.assertEqual, 'lalala')
factory.d.addCallback(lambda x : s.stopService())
factory.d.addCallback(lambda x : TestEcho.d)
factory.d.addCallback(self._cbTestUnix, factory, s)
return factory.d
def _cbTestUnix(self, ignored, factory, s):
TestEcho.d = defer.Deferred()
factory.line = None
factory.d = defer.Deferred()
s.startService()
factory.d.addCallback(self.assertEqual, 'lalala')
factory.d.addCallback(lambda x : s.stopService())
factory.d.addCallback(lambda x : TestEcho.d)
return factory.d
def testVolatile(self):
if not interfaces.IReactorUNIX(reactor, None):
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
factory = protocol.ServerFactory()
factory.protocol = wire.Echo
t = internet.UNIXServer('echo.skt', factory)
t.startService()
self.failIfIdentical(t._port, None)
t1 = copy.copy(t)
self.assertIdentical(t1._port, None)
t.stopService()
self.assertIdentical(t._port, None)
self.failIf(t.running)
factory = protocol.ClientFactory()
factory.protocol = wire.Echo
t = internet.UNIXClient('echo.skt', factory)
t.startService()
self.failIfIdentical(t._connection, None)
t1 = copy.copy(t)
self.assertIdentical(t1._connection, None)
t.stopService()
self.assertIdentical(t._connection, None)
self.failIf(t.running)
def testStoppingServer(self):
if not interfaces.IReactorUNIX(reactor, None):
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
factory = protocol.ServerFactory()
factory.protocol = wire.Echo
t = internet.UNIXServer('echo.skt', factory)
t.startService()
t.stopService()
self.failIf(t.running)
factory = protocol.ClientFactory()
d = defer.Deferred()
factory.clientConnectionFailed = lambda *args: d.callback(None)
reactor.connectUNIX('echo.skt', factory)
return d
def testPickledTimer(self):
target = TimerTarget()
t0 = internet.TimerService(1, target.append, "hello")
t0.startService()
s = pickle.dumps(t0)
t0.stopService()
t = pickle.loads(s)
self.failIf(t.running)
def testBrokenTimer(self):
d = defer.Deferred()
t = internet.TimerService(1, lambda: 1 // 0)
oldFailed = t._failed
def _failed(why):
oldFailed(why)
d.callback(None)
t._failed = _failed
t.startService()
d.addCallback(lambda x : t.stopService)
d.addCallback(lambda x : self.assertEqual(
[ZeroDivisionError],
[o.value.__class__ for o in self.flushLoggedErrors(ZeroDivisionError)]))
return d
def test_everythingThere(self):
"""
L{twisted.application.internet} dynamically defines a set of
L{service.Service} subclasses that in general have corresponding
reactor.listenXXX or reactor.connectXXX calls.
"""
trans = 'TCP UNIX SSL UDP UNIXDatagram Multicast'.split()
for tran in trans[:]:
if not getattr(interfaces, "IReactor" + tran)(reactor, None):
trans.remove(tran)
for tran in trans:
for side in 'Server Client'.split():
if tran == "Multicast" and side == "Client":
continue
self.assertTrue(hasattr(internet, tran + side))
method = getattr(internet, tran + side).method
prefix = {'Server': 'listen', 'Client': 'connect'}[side]
self.assertTrue(hasattr(reactor, prefix + method) or
(prefix == "connect" and method == "UDP"))
o = getattr(internet, tran + side)()
self.assertEqual(service.IService(o), o)
def test_importAll(self):
"""
L{twisted.application.internet} dynamically defines L{service.Service}
subclasses. This test ensures that the subclasses exposed by C{__all__}
are valid attributes of the module.
"""
for cls in internet.__all__:
self.assertTrue(
hasattr(internet, cls),
'%s not importable from twisted.application.internet' % (cls,))
def test_reactorParametrizationInServer(self):
"""
L{internet._AbstractServer} supports a C{reactor} keyword argument
that can be used to parametrize the reactor used to listen for
connections.
"""
reactor = MemoryReactor()
factory = object()
t = internet.TCPServer(1234, factory, reactor=reactor)
t.startService()
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
def test_reactorParametrizationInClient(self):
"""
L{internet._AbstractClient} supports a C{reactor} keyword arguments
that can be used to parametrize the reactor used to create new client
connections.
"""
reactor = MemoryReactor()
factory = protocol.ClientFactory()
t = internet.TCPClient('127.0.0.1', 1234, factory, reactor=reactor)
t.startService()
self.assertEqual(
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
def test_reactorParametrizationInServerMultipleStart(self):
"""
Like L{test_reactorParametrizationInServer}, but stop and restart the
service and check that the given reactor is still used.
"""
reactor = MemoryReactor()
factory = protocol.Factory()
t = internet.TCPServer(1234, factory, reactor=reactor)
t.startService()
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
t.stopService()
t.startService()
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
def test_reactorParametrizationInClientMultipleStart(self):
"""
Like L{test_reactorParametrizationInClient}, but stop and restart the
service and check that the given reactor is still used.
"""
reactor = MemoryReactor()
factory = protocol.ClientFactory()
t = internet.TCPClient('127.0.0.1', 1234, factory, reactor=reactor)
t.startService()
self.assertEqual(
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
t.stopService()
t.startService()
self.assertEqual(
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
class TestTimerBasic(unittest.TestCase):
def testTimerRuns(self):
d = defer.Deferred()
self.t = internet.TimerService(1, d.callback, 'hello')
self.t.startService()
d.addCallback(self.assertEqual, 'hello')
d.addCallback(lambda x : self.t.stopService())
d.addCallback(lambda x : self.failIf(self.t.running))
return d
def tearDown(self):
return self.t.stopService()
def testTimerRestart(self):
# restart the same TimerService
d1 = defer.Deferred()
d2 = defer.Deferred()
work = [(d2, "bar"), (d1, "foo")]
def trigger():
d, arg = work.pop()
d.callback(arg)
self.t = internet.TimerService(1, trigger)
self.t.startService()
def onFirstResult(result):
self.assertEqual(result, 'foo')
return self.t.stopService()
def onFirstStop(ignored):
self.failIf(self.t.running)
self.t.startService()
return d2
def onSecondResult(result):
self.assertEqual(result, 'bar')
self.t.stopService()
d1.addCallback(onFirstResult)
d1.addCallback(onFirstStop)
d1.addCallback(onSecondResult)
return d1
def testTimerLoops(self):
l = []
def trigger(data, number, d):
l.append(data)
if len(l) == number:
d.callback(l)
d = defer.Deferred()
self.t = internet.TimerService(0.01, trigger, "hello", 10, d)
self.t.startService()
d.addCallback(self.assertEqual, ['hello'] * 10)
d.addCallback(lambda x : self.t.stopService())
return d
class FakeReactor(reactors.Reactor):
"""
A fake reactor with a hooked install method.
"""
def __init__(self, install, *args, **kwargs):
"""
@param install: any callable that will be used as install method.
@type install: C{callable}
"""
reactors.Reactor.__init__(self, *args, **kwargs)
self.install = install
class PluggableReactorTestCase(TwistedModulesMixin, unittest.TestCase):
"""
Tests for the reactor discovery/inspection APIs.
"""
def setUp(self):
"""
Override the L{reactors.getPlugins} function, normally bound to
L{twisted.plugin.getPlugins}, in order to control which
L{IReactorInstaller} plugins are seen as available.
C{self.pluginResults} can be customized and will be used as the
result of calls to C{reactors.getPlugins}.
"""
self.pluginCalls = []
self.pluginResults = []
self.originalFunction = reactors.getPlugins
reactors.getPlugins = self._getPlugins
def tearDown(self):
"""
Restore the original L{reactors.getPlugins}.
"""
reactors.getPlugins = self.originalFunction
def _getPlugins(self, interface, package=None):
"""
Stand-in for the real getPlugins method which records its arguments
and returns a fixed result.
"""
self.pluginCalls.append((interface, package))
return list(self.pluginResults)
def test_getPluginReactorTypes(self):
"""
Test that reactor plugins are returned from L{getReactorTypes}
"""
name = 'fakereactortest'
package = __name__ + '.fakereactor'
description = 'description'
self.pluginResults = [reactors.Reactor(name, package, description)]
reactorTypes = reactors.getReactorTypes()
self.assertEqual(
self.pluginCalls,
[(reactors.IReactorInstaller, None)])
for r in reactorTypes:
if r.shortName == name:
self.assertEqual(r.description, description)
break
else:
self.fail("Reactor plugin not present in getReactorTypes() result")
def test_reactorInstallation(self):
"""
Test that L{reactors.Reactor.install} loads the correct module and
calls its install attribute.
"""
installed = []
def install():
installed.append(True)
fakeReactor = FakeReactor(install,
'fakereactortest', __name__, 'described')
modules = {'fakereactortest': fakeReactor}
self.replaceSysModules(modules)
installer = reactors.Reactor('fakereactor', 'fakereactortest', 'described')
installer.install()
self.assertEqual(installed, [True])
def test_installReactor(self):
"""
Test that the L{reactors.installReactor} function correctly installs
the specified reactor.
"""
installed = []
def install():
installed.append(True)
name = 'fakereactortest'
package = __name__
description = 'description'
self.pluginResults = [FakeReactor(install, name, package, description)]
reactors.installReactor(name)
self.assertEqual(installed, [True])
def test_installReactorReturnsReactor(self):
"""
Test that the L{reactors.installReactor} function correctly returns
the installed reactor.
"""
reactor = object()
def install():
from twisted import internet
self.patch(internet, 'reactor', reactor)
name = 'fakereactortest'
package = __name__
description = 'description'
self.pluginResults = [FakeReactor(install, name, package, description)]
installed = reactors.installReactor(name)
self.assertIdentical(installed, reactor)
def test_installReactorMultiplePlugins(self):
"""
Test that the L{reactors.installReactor} function correctly installs
the specified reactor when there are multiple reactor plugins.
"""
installed = []
def install():
installed.append(True)
name = 'fakereactortest'
package = __name__
description = 'description'
fakeReactor = FakeReactor(install, name, package, description)
otherReactor = FakeReactor(lambda: None,
"otherreactor", package, description)
self.pluginResults = [otherReactor, fakeReactor]
reactors.installReactor(name)
self.assertEqual(installed, [True])
def test_installNonExistentReactor(self):
"""
Test that L{reactors.installReactor} raises L{reactors.NoSuchReactor}
when asked to install a reactor which it cannot find.
"""
self.pluginResults = []
self.assertRaises(
reactors.NoSuchReactor,
reactors.installReactor, 'somereactor')
def test_installNotAvailableReactor(self):
"""
Test that L{reactors.installReactor} raises an exception when asked to
install a reactor which doesn't work in this environment.
"""
def install():
raise ImportError("Missing foo bar")
name = 'fakereactortest'
package = __name__
description = 'description'
self.pluginResults = [FakeReactor(install, name, package, description)]
self.assertRaises(ImportError, reactors.installReactor, name)
def test_reactorSelectionMixin(self):
"""
Test that the reactor selected is installed as soon as possible, ie
when the option is parsed.
"""
executed = []
INSTALL_EVENT = 'reactor installed'
SUBCOMMAND_EVENT = 'subcommands loaded'
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
def subCommands(self):
executed.append(SUBCOMMAND_EVENT)
return [('subcommand', None, lambda: self, 'test subcommand')]
subCommands = property(subCommands)
def install():
executed.append(INSTALL_EVENT)
self.pluginResults = [
FakeReactor(install, 'fakereactortest', __name__, 'described')
]
options = ReactorSelectionOptions()
options.parseOptions(['--reactor', 'fakereactortest', 'subcommand'])
self.assertEqual(executed[0], INSTALL_EVENT)
self.assertEqual(executed.count(INSTALL_EVENT), 1)
self.assertEqual(options["reactor"], "fakereactortest")
def test_reactorSelectionMixinNonExistent(self):
"""
Test that the usage mixin exits when trying to use a non existent
reactor (the name not matching to any reactor), giving an error
message.
"""
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
pass
self.pluginResults = []
options = ReactorSelectionOptions()
options.messageOutput = StringIO()
e = self.assertRaises(usage.UsageError, options.parseOptions,
['--reactor', 'fakereactortest', 'subcommand'])
self.assertIn("fakereactortest", e.args[0])
self.assertIn("help-reactors", e.args[0])
def test_reactorSelectionMixinNotAvailable(self):
"""
Test that the usage mixin exits when trying to use a reactor not
available (the reactor raises an error at installation), giving an
error message.
"""
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
pass
message = "Missing foo bar"
def install():
raise ImportError(message)
name = 'fakereactortest'
package = __name__
description = 'description'
self.pluginResults = [FakeReactor(install, name, package, description)]
options = ReactorSelectionOptions()
options.messageOutput = StringIO()
e = self.assertRaises(usage.UsageError, options.parseOptions,
['--reactor', 'fakereactortest', 'subcommand'])
self.assertIn(message, e.args[0])
self.assertIn("help-reactors", e.args[0])

View file

@ -0,0 +1,278 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import StringIO
import sys
# Twisted Imports
from twisted.trial import unittest
from twisted.spread import banana
from twisted.python import failure
from twisted.internet import protocol, main
class MathTestCase(unittest.TestCase):
def testInt2b128(self):
funkylist = range(0,100) + range(1000,1100) + range(1000000,1000100) + [1024 **10l]
for i in funkylist:
x = StringIO.StringIO()
banana.int2b128(i, x.write)
v = x.getvalue()
y = banana.b1282int(v)
assert y == i, "y = %s; i = %s" % (y,i)
class BananaTestCase(unittest.TestCase):
encClass = banana.Banana
def setUp(self):
self.io = StringIO.StringIO()
self.enc = self.encClass()
self.enc.makeConnection(protocol.FileWrapper(self.io))
self.enc._selectDialect("none")
self.enc.expressionReceived = self.putResult
def putResult(self, result):
self.result = result
def tearDown(self):
self.enc.connectionLost(failure.Failure(main.CONNECTION_DONE))
del self.enc
def testString(self):
self.enc.sendEncoded("hello")
l = []
self.enc.dataReceived(self.io.getvalue())
assert self.result == 'hello'
def test_int(self):
"""
A positive integer less than 2 ** 32 should round-trip through
banana without changing value and should come out represented
as an C{int} (regardless of the type which was encoded).
"""
for value in (10151, 10151L):
self.enc.sendEncoded(value)
self.enc.dataReceived(self.io.getvalue())
self.assertEqual(self.result, 10151)
self.assertIsInstance(self.result, int)
def test_largeLong(self):
"""
Integers greater than 2 ** 32 and less than -2 ** 32 should
round-trip through banana without changing value and should
come out represented as C{int} instances if the value fits
into that type on the receiving platform.
"""
for exp in (32, 64, 128, 256):
for add in (0, 1):
m = 2 ** exp + add
for n in (m, -m-1):
self.io.truncate(0)
self.enc.sendEncoded(n)
self.enc.dataReceived(self.io.getvalue())
self.assertEqual(self.result, n)
if n > sys.maxint or n < -sys.maxint - 1:
self.assertIsInstance(self.result, long)
else:
self.assertIsInstance(self.result, int)
def _getSmallest(self):
# How many bytes of prefix our implementation allows
bytes = self.enc.prefixLimit
# How many useful bits we can extract from that based on Banana's
# base-128 representation.
bits = bytes * 7
# The largest number we _should_ be able to encode
largest = 2 ** bits - 1
# The smallest number we _shouldn't_ be able to encode
smallest = largest + 1
return smallest
def test_encodeTooLargeLong(self):
"""
Test that a long above the implementation-specific limit is rejected
as too large to be encoded.
"""
smallest = self._getSmallest()
self.assertRaises(banana.BananaError, self.enc.sendEncoded, smallest)
def test_decodeTooLargeLong(self):
"""
Test that a long above the implementation specific limit is rejected
as too large to be decoded.
"""
smallest = self._getSmallest()
self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
self.enc.sendEncoded(smallest)
encoded = self.io.getvalue()
self.io.truncate(0)
self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)
def _getLargest(self):
return -self._getSmallest()
def test_encodeTooSmallLong(self):
"""
Test that a negative long below the implementation-specific limit is
rejected as too small to be encoded.
"""
largest = self._getLargest()
self.assertRaises(banana.BananaError, self.enc.sendEncoded, largest)
def test_decodeTooSmallLong(self):
"""
Test that a negative long below the implementation specific limit is
rejected as too small to be decoded.
"""
largest = self._getLargest()
self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
self.enc.sendEncoded(largest)
encoded = self.io.getvalue()
self.io.truncate(0)
self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)
def testNegativeLong(self):
self.enc.sendEncoded(-1015l)
self.enc.dataReceived(self.io.getvalue())
assert self.result == -1015l, "should be -1015l, got %s" % self.result
def testInteger(self):
self.enc.sendEncoded(1015)
self.enc.dataReceived(self.io.getvalue())
assert self.result == 1015, "should be 1015, got %s" % self.result
def testNegative(self):
self.enc.sendEncoded(-1015)
self.enc.dataReceived(self.io.getvalue())
assert self.result == -1015, "should be -1015, got %s" % self.result
def testFloat(self):
self.enc.sendEncoded(1015.)
self.enc.dataReceived(self.io.getvalue())
assert self.result == 1015.
def testList(self):
foo = [1, 2, [3, 4], [30.5, 40.2], 5, ["six", "seven", ["eight", 9]], [10], []]
self.enc.sendEncoded(foo)
self.enc.dataReceived(self.io.getvalue())
assert self.result == foo, "%s!=%s" % (repr(self.result), repr(self.result))
def testPartial(self):
foo = [1, 2, [3, 4], [30.5, 40.2], 5,
["six", "seven", ["eight", 9]], [10],
# TODO: currently the C implementation's a bit buggy...
sys.maxint * 3l, sys.maxint * 2l, sys.maxint * -2l]
self.enc.sendEncoded(foo)
for byte in self.io.getvalue():
self.enc.dataReceived(byte)
assert self.result == foo, "%s!=%s" % (repr(self.result), repr(foo))
def feed(self, data):
for byte in data:
self.enc.dataReceived(byte)
def testOversizedList(self):
data = '\x02\x01\x01\x01\x01\x80'
# list(size=0x0101010102, about 4.3e9)
self.failUnlessRaises(banana.BananaError, self.feed, data)
def testOversizedString(self):
data = '\x02\x01\x01\x01\x01\x82'
# string(size=0x0101010102, about 4.3e9)
self.failUnlessRaises(banana.BananaError, self.feed, data)
def testCrashString(self):
crashString = '\x00\x00\x00\x00\x04\x80'
# string(size=0x0400000000, about 17.2e9)
# cBanana would fold that into a 32-bit 'int', then try to allocate
# a list with PyList_New(). cBanana ignored the NULL return value,
# so it would segfault when trying to free the imaginary list.
# This variant doesn't segfault straight out in my environment.
# Instead, it takes up large amounts of CPU and memory...
#crashString = '\x00\x00\x00\x00\x01\x80'
# print repr(crashString)
#self.failUnlessRaises(Exception, self.enc.dataReceived, crashString)
try:
# should now raise MemoryError
self.enc.dataReceived(crashString)
except banana.BananaError:
pass
def testCrashNegativeLong(self):
# There was a bug in cBanana which relied on negating a negative integer
# always giving a postive result, but for the lowest possible number in
# 2s-complement arithmetic, that's not true, i.e.
# long x = -2147483648;
# long y = -x;
# x == y; /* true! */
# (assuming 32-bit longs)
self.enc.sendEncoded(-2147483648)
self.enc.dataReceived(self.io.getvalue())
assert self.result == -2147483648, "should be -2147483648, got %s" % self.result
def test_sizedIntegerTypes(self):
"""
Test that integers below the maximum C{INT} token size cutoff are
serialized as C{INT} or C{NEG} and that larger integers are
serialized as C{LONGINT} or C{LONGNEG}.
"""
def encoded(n):
self.io.seek(0)
self.io.truncate()
self.enc.sendEncoded(n)
return self.io.getvalue()
baseIntIn = +2147483647
baseNegIn = -2147483648
baseIntOut = '\x7f\x7f\x7f\x07\x81'
self.assertEqual(encoded(baseIntIn - 2), '\x7d' + baseIntOut)
self.assertEqual(encoded(baseIntIn - 1), '\x7e' + baseIntOut)
self.assertEqual(encoded(baseIntIn - 0), '\x7f' + baseIntOut)
baseLongIntOut = '\x00\x00\x00\x08\x85'
self.assertEqual(encoded(baseIntIn + 1), '\x00' + baseLongIntOut)
self.assertEqual(encoded(baseIntIn + 2), '\x01' + baseLongIntOut)
self.assertEqual(encoded(baseIntIn + 3), '\x02' + baseLongIntOut)
baseNegOut = '\x7f\x7f\x7f\x07\x83'
self.assertEqual(encoded(baseNegIn + 2), '\x7e' + baseNegOut)
self.assertEqual(encoded(baseNegIn + 1), '\x7f' + baseNegOut)
self.assertEqual(encoded(baseNegIn + 0), '\x00\x00\x00\x00\x08\x83')
baseLongNegOut = '\x00\x00\x00\x08\x86'
self.assertEqual(encoded(baseNegIn - 1), '\x01' + baseLongNegOut)
self.assertEqual(encoded(baseNegIn - 2), '\x02' + baseLongNegOut)
self.assertEqual(encoded(baseNegIn - 3), '\x03' + baseLongNegOut)
class GlobalCoderTests(unittest.TestCase):
"""
Tests for the free functions L{banana.encode} and L{banana.decode}.
"""
def test_statelessDecode(self):
"""
Test that state doesn't carry over between calls to L{banana.decode}.
"""
# Banana encoding of 2 ** 449
undecodable = '\x7f' * 65 + '\x85'
self.assertRaises(banana.BananaError, banana.decode, undecodable)
# Banana encoding of 1
decodable = '\x01\x81'
self.assertEqual(banana.decode(decodable), 1)

View file

@ -0,0 +1,623 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.compat}.
"""
from __future__ import division, absolute_import
import socket, sys, traceback
from twisted.trial import unittest
from twisted.python.compat import reduce, execfile, _PY3
from twisted.python.compat import comparable, cmp, nativeString, networkString
from twisted.python.compat import unicode as unicodeCompat, lazyByteSlice
from twisted.python.compat import reraise, NativeStringIO, iterbytes, intToBytes
from twisted.python.filepath import FilePath
class CompatTestCase(unittest.SynchronousTestCase):
"""
Various utility functions in C{twisted.python.compat} provide same
functionality as modern Python variants.
"""
def test_set(self):
"""
L{set} should behave like the expected set interface.
"""
a = set()
a.add('b')
a.add('c')
a.add('a')
b = list(a)
b.sort()
self.assertEqual(b, ['a', 'b', 'c'])
a.remove('b')
b = list(a)
b.sort()
self.assertEqual(b, ['a', 'c'])
a.discard('d')
b = set(['r', 's'])
d = a.union(b)
b = list(d)
b.sort()
self.assertEqual(b, ['a', 'c', 'r', 's'])
def test_frozenset(self):
"""
L{frozenset} should behave like the expected frozenset interface.
"""
a = frozenset(['a', 'b'])
self.assertRaises(AttributeError, getattr, a, "add")
self.assertEqual(sorted(a), ['a', 'b'])
b = frozenset(['r', 's'])
d = a.union(b)
b = list(d)
b.sort()
self.assertEqual(b, ['a', 'b', 'r', 's'])
def test_reduce(self):
"""
L{reduce} should behave like the builtin reduce.
"""
self.assertEqual(15, reduce(lambda x, y: x + y, [1, 2, 3, 4, 5]))
self.assertEqual(16, reduce(lambda x, y: x + y, [1, 2, 3, 4, 5], 1))
class IPv6Tests(unittest.SynchronousTestCase):
"""
C{inet_pton} and C{inet_ntop} implementations support IPv6.
"""
def testNToP(self):
from twisted.python.compat import inet_ntop
f = lambda a: inet_ntop(socket.AF_INET6, a)
g = lambda a: inet_ntop(socket.AF_INET, a)
self.assertEqual('::', f('\x00' * 16))
self.assertEqual('::1', f('\x00' * 15 + '\x01'))
self.assertEqual(
'aef:b01:506:1001:ffff:9997:55:170',
f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70'))
self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
self.assertEqual('100::', f('\x01' + '\x00' * 15))
self.assertEqual('100::1', f('\x01' + '\x00' * 14 + '\x01'))
def testPToN(self):
from twisted.python.compat import inet_pton
f = lambda a: inet_pton(socket.AF_INET6, a)
g = lambda a: inet_pton(socket.AF_INET, a)
self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
self.assertEqual('\x00' * 16, f('::'))
self.assertEqual('\x00' * 16, f('0::0'))
self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
self.assertEqual(
'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae'))
self.assertEqual('\x00' * 14 + '\x00\x01', f('::1'))
self.assertEqual('\x00' * 12 + '\x01\x02\x03\x04', f('::1.2.3.4'))
self.assertEqual(
'\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00\x06\x01\x02\x03\xff',
f('1:2:3:4:5:6:1.2.3.255'))
for badaddr in ['1:2:3:4:5:6:7:8:', ':1:2:3:4:5:6:7:8', '1::2::3',
'1:::3', ':::', '1:2', '::1.2', '1.2.3.4::',
'abcd:1.2.3.4:abcd:abcd:abcd:abcd:abcd',
'1234:1.2.3.4:1234:1234:1234:1234:1234:1234',
'1.2.3.4']:
self.assertRaises(ValueError, f, badaddr)
if _PY3:
IPv6Tests.skip = "These tests are only relevant to old versions of Python"
class ExecfileCompatTestCase(unittest.SynchronousTestCase):
"""
Tests for the Python 3-friendly L{execfile} implementation.
"""
def writeScript(self, content):
"""
Write L{content} to a new temporary file, returning the L{FilePath}
for the new file.
"""
path = self.mktemp()
with open(path, "wb") as f:
f.write(content.encode("ascii"))
return FilePath(path.encode("utf-8"))
def test_execfileGlobals(self):
"""
L{execfile} executes the specified file in the given global namespace.
"""
script = self.writeScript(u"foo += 1\n")
globalNamespace = {"foo": 1}
execfile(script.path, globalNamespace)
self.assertEqual(2, globalNamespace["foo"])
def test_execfileGlobalsAndLocals(self):
"""
L{execfile} executes the specified file in the given global and local
namespaces.
"""
script = self.writeScript(u"foo += 1\n")
globalNamespace = {"foo": 10}
localNamespace = {"foo": 20}
execfile(script.path, globalNamespace, localNamespace)
self.assertEqual(10, globalNamespace["foo"])
self.assertEqual(21, localNamespace["foo"])
def test_execfileUniversalNewlines(self):
"""
L{execfile} reads in the specified file using universal newlines so
that scripts written on one platform will work on another.
"""
for lineEnding in u"\n", u"\r", u"\r\n":
script = self.writeScript(u"foo = 'okay'" + lineEnding)
globalNamespace = {"foo": None}
execfile(script.path, globalNamespace)
self.assertEqual("okay", globalNamespace["foo"])
class PY3Tests(unittest.SynchronousTestCase):
"""
Identification of Python 2 vs. Python 3.
"""
def test_python2(self):
"""
On Python 2, C{_PY3} is False.
"""
if sys.version.startswith("2."):
self.assertFalse(_PY3)
def test_python3(self):
"""
On Python 3, C{_PY3} is True.
"""
if sys.version.startswith("3."):
self.assertTrue(_PY3)
@comparable
class Comparable(object):
"""
Objects that can be compared to each other, but not others.
"""
def __init__(self, value):
self.value = value
def __cmp__(self, other):
if not isinstance(other, Comparable):
return NotImplemented
return cmp(self.value, other.value)
class ComparableTests(unittest.SynchronousTestCase):
"""
L{comparable} decorated classes emulate Python 2's C{__cmp__} semantics.
"""
def test_equality(self):
"""
Instances of a class that is decorated by C{comparable} support
equality comparisons.
"""
# Make explicitly sure we're using ==:
self.assertTrue(Comparable(1) == Comparable(1))
self.assertFalse(Comparable(2) == Comparable(1))
def test_nonEquality(self):
"""
Instances of a class that is decorated by C{comparable} support
inequality comparisons.
"""
# Make explicitly sure we're using !=:
self.assertFalse(Comparable(1) != Comparable(1))
self.assertTrue(Comparable(2) != Comparable(1))
def test_greaterThan(self):
"""
Instances of a class that is decorated by C{comparable} support
greater-than comparisons.
"""
self.assertTrue(Comparable(2) > Comparable(1))
self.assertFalse(Comparable(0) > Comparable(3))
def test_greaterThanOrEqual(self):
"""
Instances of a class that is decorated by C{comparable} support
greater-than-or-equal comparisons.
"""
self.assertTrue(Comparable(1) >= Comparable(1))
self.assertTrue(Comparable(2) >= Comparable(1))
self.assertFalse(Comparable(0) >= Comparable(3))
def test_lessThan(self):
"""
Instances of a class that is decorated by C{comparable} support
less-than comparisons.
"""
self.assertTrue(Comparable(0) < Comparable(3))
self.assertFalse(Comparable(2) < Comparable(0))
def test_lessThanOrEqual(self):
"""
Instances of a class that is decorated by C{comparable} support
less-than-or-equal comparisons.
"""
self.assertTrue(Comparable(3) <= Comparable(3))
self.assertTrue(Comparable(0) <= Comparable(3))
self.assertFalse(Comparable(2) <= Comparable(0))
class Python3ComparableTests(unittest.SynchronousTestCase):
"""
Python 3-specific functionality of C{comparable}.
"""
def test_notImplementedEquals(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__eq__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__eq__(object()), NotImplemented)
def test_notImplementedNotEquals(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__ne__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__ne__(object()), NotImplemented)
def test_notImplementedGreaterThan(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__gt__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__gt__(object()), NotImplemented)
def test_notImplementedLessThan(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__lt__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__lt__(object()), NotImplemented)
def test_notImplementedGreaterThanEquals(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__ge__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__ge__(object()), NotImplemented)
def test_notImplementedLessThanEquals(self):
"""
Instances of a class that is decorated by C{comparable} support
returning C{NotImplemented} from C{__le__} if it is returned by the
underlying C{__cmp__} call.
"""
self.assertEqual(Comparable(1).__le__(object()), NotImplemented)
if not _PY3:
# On Python 2, we just use __cmp__ directly, so checking detailed
# comparison methods doesn't makes sense.
Python3ComparableTests.skip = "Python 3 only."
class CmpTests(unittest.SynchronousTestCase):
"""
L{cmp} should behave like the built-in Python 2 C{cmp}.
"""
def test_equals(self):
"""
L{cmp} returns 0 for equal objects.
"""
self.assertEqual(cmp(u"a", u"a"), 0)
self.assertEqual(cmp(1, 1), 0)
self.assertEqual(cmp([1], [1]), 0)
def test_greaterThan(self):
"""
L{cmp} returns 1 if its first argument is bigger than its second.
"""
self.assertEqual(cmp(4, 0), 1)
self.assertEqual(cmp(b"z", b"a"), 1)
def test_lessThan(self):
"""
L{cmp} returns -1 if its first argument is smaller than its second.
"""
self.assertEqual(cmp(0.1, 2.3), -1)
self.assertEqual(cmp(b"a", b"d"), -1)
class StringTests(unittest.SynchronousTestCase):
"""
Compatibility functions and types for strings.
"""
def assertNativeString(self, original, expected):
"""
Raise an exception indicating a failed test if the output of
C{nativeString(original)} is unequal to the expected string, or is not
a native string.
"""
self.assertEqual(nativeString(original), expected)
self.assertIsInstance(nativeString(original), str)
def test_nonASCIIBytesToString(self):
"""
C{nativeString} raises a C{UnicodeError} if input bytes are not ASCII
decodable.
"""
self.assertRaises(UnicodeError, nativeString, b"\xFF")
def test_nonASCIIUnicodeToString(self):
"""
C{nativeString} raises a C{UnicodeError} if input Unicode is not ASCII
encodable.
"""
self.assertRaises(UnicodeError, nativeString, u"\u1234")
def test_bytesToString(self):
"""
C{nativeString} converts bytes to the native string format, assuming
an ASCII encoding if applicable.
"""
self.assertNativeString(b"hello", "hello")
def test_unicodeToString(self):
"""
C{nativeString} converts unicode to the native string format, assuming
an ASCII encoding if applicable.
"""
self.assertNativeString(u"Good day", "Good day")
def test_stringToString(self):
"""
C{nativeString} leaves native strings as native strings.
"""
self.assertNativeString("Hello!", "Hello!")
def test_unexpectedType(self):
"""
C{nativeString} raises a C{TypeError} if given an object that is not a
string of some sort.
"""
self.assertRaises(TypeError, nativeString, 1)
def test_unicode(self):
"""
C{compat.unicode} is C{str} on Python 3, C{unicode} on Python 2.
"""
if _PY3:
expected = str
else:
expected = unicode
self.assertTrue(unicodeCompat is expected)
def test_nativeStringIO(self):
"""
L{NativeStringIO} is a file-like object that stores native strings in
memory.
"""
f = NativeStringIO()
f.write("hello")
f.write(" there")
self.assertEqual(f.getvalue(), "hello there")
class NetworkStringTests(unittest.SynchronousTestCase):
"""
Tests for L{networkString}.
"""
def test_bytes(self):
"""
L{networkString} returns a C{bytes} object passed to it unmodified.
"""
self.assertEqual(b"foo", networkString(b"foo"))
def test_bytesOutOfRange(self):
"""
L{networkString} raises C{UnicodeError} if passed a C{bytes} instance
containing bytes not used by ASCII.
"""
self.assertRaises(
UnicodeError, networkString, u"\N{SNOWMAN}".encode('utf-8'))
if _PY3:
test_bytes.skip = test_bytesOutOfRange.skip = (
"Bytes behavior of networkString only provided on Python 2.")
def test_unicode(self):
"""
L{networkString} returns a C{unicode} object passed to it encoded into a
C{bytes} instance.
"""
self.assertEqual(b"foo", networkString(u"foo"))
def test_unicodeOutOfRange(self):
"""
L{networkString} raises L{UnicodeError} if passed a C{unicode} instance
containing characters not encodable in ASCII.
"""
self.assertRaises(
UnicodeError, networkString, u"\N{SNOWMAN}")
if not _PY3:
test_unicode.skip = test_unicodeOutOfRange.skip = (
"Unicode behavior of networkString only provided on Python 3.")
def test_nonString(self):
"""
L{networkString} raises L{TypeError} if passed a non-string object or
the wrong type of string object.
"""
self.assertRaises(TypeError, networkString, object())
if _PY3:
self.assertRaises(TypeError, networkString, b"bytes")
else:
self.assertRaises(TypeError, networkString, u"text")
class ReraiseTests(unittest.SynchronousTestCase):
"""
L{reraise} re-raises exceptions on both Python 2 and Python 3.
"""
def test_reraiseWithNone(self):
"""
Calling L{reraise} with an exception instance and a traceback of
C{None} re-raises it with a new traceback.
"""
try:
1/0
except:
typ, value, tb = sys.exc_info()
try:
reraise(value, None)
except:
typ2, value2, tb2 = sys.exc_info()
self.assertEqual(typ2, ZeroDivisionError)
self.assertTrue(value is value2)
self.assertNotEqual(traceback.format_tb(tb)[-1],
traceback.format_tb(tb2)[-1])
else:
self.fail("The exception was not raised.")
def test_reraiseWithTraceback(self):
"""
Calling L{reraise} with an exception instance and a traceback
re-raises the exception with the given traceback.
"""
try:
1/0
except:
typ, value, tb = sys.exc_info()
try:
reraise(value, tb)
except:
typ2, value2, tb2 = sys.exc_info()
self.assertEqual(typ2, ZeroDivisionError)
self.assertTrue(value is value2)
self.assertEqual(traceback.format_tb(tb)[-1],
traceback.format_tb(tb2)[-1])
else:
self.fail("The exception was not raised.")
class Python3BytesTests(unittest.SynchronousTestCase):
"""
Tests for L{iterbytes}, L{intToBytes}, L{lazyByteSlice}.
"""
def test_iteration(self):
"""
When L{iterbytes} is called with a bytestring, the returned object
can be iterated over, resulting in the individual bytes of the
bytestring.
"""
input = b"abcd"
result = list(iterbytes(input))
self.assertEqual(result, [b'a', b'b', b'c', b'd'])
def test_intToBytes(self):
"""
When L{intToBytes} is called with an integer, the result is an
ASCII-encoded string representation of the number.
"""
self.assertEqual(intToBytes(213), b"213")
def test_lazyByteSliceNoOffset(self):
"""
L{lazyByteSlice} called with some bytes returns a semantically equal version
of these bytes.
"""
data = b'123XYZ'
self.assertEqual(bytes(lazyByteSlice(data)), data)
def test_lazyByteSliceOffset(self):
"""
L{lazyByteSlice} called with some bytes and an offset returns a semantically
equal version of these bytes starting at the given offset.
"""
data = b'123XYZ'
self.assertEqual(bytes(lazyByteSlice(data, 2)), data[2:])
def test_lazyByteSliceOffsetAndLength(self):
"""
L{lazyByteSlice} called with some bytes, an offset and a length returns a
semantically equal version of these bytes starting at the given
offset, up to the given length.
"""
data = b'123XYZ'
self.assertEqual(bytes(lazyByteSlice(data, 2, 3)), data[2:5])

View file

@ -0,0 +1,51 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.context}.
"""
from __future__ import division, absolute_import
from twisted.trial.unittest import SynchronousTestCase
from twisted.python import context
class ContextTest(SynchronousTestCase):
"""
Tests for the module-scope APIs for L{twisted.python.context}.
"""
def test_notPresentIfNotSet(self):
"""
Arbitrary keys which have not been set in the context have an associated
value of C{None}.
"""
self.assertEqual(context.get("x"), None)
def test_setByCall(self):
"""
Values may be associated with keys by passing them in a dictionary as
the first argument to L{twisted.python.context.call}.
"""
self.assertEqual(context.call({"x": "y"}, context.get, "x"), "y")
def test_unsetAfterCall(self):
"""
After a L{twisted.python.context.call} completes, keys specified in the
call are no longer associated with the values from that call.
"""
context.call({"x": "y"}, lambda: None)
self.assertEqual(context.get("x"), None)
def test_setDefault(self):
"""
A default value may be set for a key in the context using
L{twisted.python.context.setDefault}.
"""
key = object()
self.addCleanup(context.defaultContextDict.pop, key, None)
context.setDefault(key, "y")
self.assertEqual("y", context.get(key))

View file

@ -0,0 +1,711 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains tests for L{twisted.internet.task.Cooperator} and
related functionality.
"""
from __future__ import division, absolute_import
from twisted.internet import reactor, defer, task
from twisted.trial import unittest
class FakeDelayedCall(object):
"""
Fake delayed call which lets us simulate the scheduler.
"""
def __init__(self, func):
"""
A function to run, later.
"""
self.func = func
self.cancelled = False
def cancel(self):
"""
Don't run my function later.
"""
self.cancelled = True
class FakeScheduler(object):
"""
A fake scheduler for testing against.
"""
def __init__(self):
"""
Create a fake scheduler with a list of work to do.
"""
self.work = []
def __call__(self, thunk):
"""
Schedule a unit of work to be done later.
"""
unit = FakeDelayedCall(thunk)
self.work.append(unit)
return unit
def pump(self):
"""
Do all of the work that is currently available to be done.
"""
work, self.work = self.work, []
for unit in work:
if not unit.cancelled:
unit.func()
class TestCooperator(unittest.TestCase):
RESULT = 'done'
def ebIter(self, err):
err.trap(task.SchedulerStopped)
return self.RESULT
def cbIter(self, ign):
self.fail()
def testStoppedRejectsNewTasks(self):
"""
Test that Cooperators refuse new tasks when they have been stopped.
"""
def testwith(stuff):
c = task.Cooperator()
c.stop()
d = c.coiterate(iter(()), stuff)
d.addCallback(self.cbIter)
d.addErrback(self.ebIter)
return d.addCallback(lambda result:
self.assertEqual(result, self.RESULT))
return testwith(None).addCallback(lambda ign: testwith(defer.Deferred()))
def testStopRunning(self):
"""
Test that a running iterator will not run to completion when the
cooperator is stopped.
"""
c = task.Cooperator()
def myiter():
for myiter.value in range(3):
yield myiter.value
myiter.value = -1
d = c.coiterate(myiter())
d.addCallback(self.cbIter)
d.addErrback(self.ebIter)
c.stop()
def doasserts(result):
self.assertEqual(result, self.RESULT)
self.assertEqual(myiter.value, -1)
d.addCallback(doasserts)
return d
def testStopOutstanding(self):
"""
An iterator run with L{Cooperator.coiterate} paused on a L{Deferred}
yielded by that iterator will fire its own L{Deferred} (the one
returned by C{coiterate}) when L{Cooperator.stop} is called.
"""
testControlD = defer.Deferred()
outstandingD = defer.Deferred()
def myiter():
reactor.callLater(0, testControlD.callback, None)
yield outstandingD
self.fail()
c = task.Cooperator()
d = c.coiterate(myiter())
def stopAndGo(ign):
c.stop()
outstandingD.callback('arglebargle')
testControlD.addCallback(stopAndGo)
d.addCallback(self.cbIter)
d.addErrback(self.ebIter)
return d.addCallback(
lambda result: self.assertEqual(result, self.RESULT))
def testUnexpectedError(self):
c = task.Cooperator()
def myiter():
if 0:
yield None
else:
raise RuntimeError()
d = c.coiterate(myiter())
return self.assertFailure(d, RuntimeError)
def testUnexpectedErrorActuallyLater(self):
def myiter():
D = defer.Deferred()
reactor.callLater(0, D.errback, RuntimeError())
yield D
c = task.Cooperator()
d = c.coiterate(myiter())
return self.assertFailure(d, RuntimeError)
def testUnexpectedErrorNotActuallyLater(self):
def myiter():
yield defer.fail(RuntimeError())
c = task.Cooperator()
d = c.coiterate(myiter())
return self.assertFailure(d, RuntimeError)
def testCooperation(self):
L = []
def myiter(things):
for th in things:
L.append(th)
yield None
groupsOfThings = ['abc', (1, 2, 3), 'def', (4, 5, 6)]
c = task.Cooperator()
tasks = []
for stuff in groupsOfThings:
tasks.append(c.coiterate(myiter(stuff)))
return defer.DeferredList(tasks).addCallback(
lambda ign: self.assertEqual(tuple(L), sum(zip(*groupsOfThings), ())))
def testResourceExhaustion(self):
output = []
def myiter():
for i in range(100):
output.append(i)
if i == 9:
_TPF.stopped = True
yield i
class _TPF:
stopped = False
def __call__(self):
return self.stopped
c = task.Cooperator(terminationPredicateFactory=_TPF)
c.coiterate(myiter()).addErrback(self.ebIter)
c._delayedCall.cancel()
# testing a private method because only the test case will ever care
# about this, so we have to carefully clean up after ourselves.
c._tick()
c.stop()
self.failUnless(_TPF.stopped)
self.assertEqual(output, list(range(10)))
def testCallbackReCoiterate(self):
"""
If a callback to a deferred returned by coiterate calls coiterate on
the same Cooperator, we should make sure to only do the minimal amount
of scheduling work. (This test was added to demonstrate a specific bug
that was found while writing the scheduler.)
"""
calls = []
class FakeCall:
def __init__(self, func):
self.func = func
def __repr__(self):
return '<FakeCall %r>' % (self.func,)
def sched(f):
self.failIf(calls, repr(calls))
calls.append(FakeCall(f))
return calls[-1]
c = task.Cooperator(scheduler=sched, terminationPredicateFactory=lambda: lambda: True)
d = c.coiterate(iter(()))
done = []
def anotherTask(ign):
c.coiterate(iter(())).addBoth(done.append)
d.addCallback(anotherTask)
work = 0
while not done:
work += 1
while calls:
calls.pop(0).func()
work += 1
if work > 50:
self.fail("Cooperator took too long")
def test_removingLastTaskStopsScheduledCall(self):
"""
If the last task in a Cooperator is removed, the scheduled call for
the next tick is cancelled, since it is no longer necessary.
This behavior is useful for tests that want to assert they have left
no reactor state behind when they're done.
"""
calls = [None]
def sched(f):
calls[0] = FakeDelayedCall(f)
return calls[0]
coop = task.Cooperator(scheduler=sched)
# Add two task; this should schedule the tick:
task1 = coop.cooperate(iter([1, 2]))
task2 = coop.cooperate(iter([1, 2]))
self.assertEqual(calls[0].func, coop._tick)
# Remove first task; scheduled call should still be going:
task1.stop()
self.assertEqual(calls[0].cancelled, False)
self.assertEqual(coop._delayedCall, calls[0])
# Remove second task; scheduled call should be cancelled:
task2.stop()
self.assertEqual(calls[0].cancelled, True)
self.assertEqual(coop._delayedCall, None)
# Add another task; scheduled call will be recreated:
coop.cooperate(iter([1, 2]))
self.assertEqual(calls[0].cancelled, False)
self.assertEqual(coop._delayedCall, calls[0])
def test_runningWhenStarted(self):
"""
L{Cooperator.running} reports C{True} if the L{Cooperator}
was started on creation.
"""
c = task.Cooperator()
self.assertTrue(c.running)
def test_runningWhenNotStarted(self):
"""
L{Cooperator.running} reports C{False} if the L{Cooperator}
has not been started.
"""
c = task.Cooperator(started=False)
self.assertFalse(c.running)
def test_runningWhenRunning(self):
"""
L{Cooperator.running} reports C{True} when the L{Cooperator}
is running.
"""
c = task.Cooperator(started=False)
c.start()
self.addCleanup(c.stop)
self.assertTrue(c.running)
def test_runningWhenStopped(self):
"""
L{Cooperator.running} reports C{False} after the L{Cooperator}
has been stopped.
"""
c = task.Cooperator(started=False)
c.start()
c.stop()
self.assertFalse(c.running)
class UnhandledException(Exception):
"""
An exception that should go unhandled.
"""
class AliasTests(unittest.TestCase):
"""
Integration test to verify that the global singleton aliases do what
they're supposed to.
"""
def test_cooperate(self):
"""
L{twisted.internet.task.cooperate} ought to run the generator that it is
"""
d = defer.Deferred()
def doit():
yield 1
yield 2
yield 3
d.callback("yay")
it = doit()
theTask = task.cooperate(it)
self.assertIn(theTask, task._theCooperator._tasks)
return d
class RunStateTests(unittest.TestCase):
"""
Tests to verify the behavior of L{CooperativeTask.pause},
L{CooperativeTask.resume}, L{CooperativeTask.stop}, exhausting the
underlying iterator, and their interactions with each other.
"""
def setUp(self):
"""
Create a cooperator with a fake scheduler and a termination predicate
that ensures only one unit of work will take place per tick.
"""
self._doDeferNext = False
self._doStopNext = False
self._doDieNext = False
self.work = []
self.scheduler = FakeScheduler()
self.cooperator = task.Cooperator(
scheduler=self.scheduler,
# Always stop after one iteration of work (return a function which
# returns a function which always returns True)
terminationPredicateFactory=lambda: lambda: True)
self.task = self.cooperator.cooperate(self.worker())
self.cooperator.start()
def worker(self):
"""
This is a sample generator which yields Deferreds when we are testing
deferral and an ascending integer count otherwise.
"""
i = 0
while True:
i += 1
if self._doDeferNext:
self._doDeferNext = False
d = defer.Deferred()
self.work.append(d)
yield d
elif self._doStopNext:
return
elif self._doDieNext:
raise UnhandledException()
else:
self.work.append(i)
yield i
def tearDown(self):
"""
Drop references to interesting parts of the fixture to allow Deferred
errors to be noticed when things start failing.
"""
del self.task
del self.scheduler
def deferNext(self):
"""
Defer the next result from my worker iterator.
"""
self._doDeferNext = True
def stopNext(self):
"""
Make the next result from my worker iterator be completion (raising
StopIteration).
"""
self._doStopNext = True
def dieNext(self):
"""
Make the next result from my worker iterator be raising an
L{UnhandledException}.
"""
def ignoreUnhandled(failure):
failure.trap(UnhandledException)
return None
self._doDieNext = True
def test_pauseResume(self):
"""
Cooperators should stop running their tasks when they're paused, and
start again when they're resumed.
"""
# first, sanity check
self.scheduler.pump()
self.assertEqual(self.work, [1])
self.scheduler.pump()
self.assertEqual(self.work, [1, 2])
# OK, now for real
self.task.pause()
self.scheduler.pump()
self.assertEqual(self.work, [1, 2])
self.task.resume()
# Resuming itself shoult not do any work
self.assertEqual(self.work, [1, 2])
self.scheduler.pump()
# But when the scheduler rolls around again...
self.assertEqual(self.work, [1, 2, 3])
def test_resumeNotPaused(self):
"""
L{CooperativeTask.resume} should raise a L{TaskNotPaused} exception if
it was not paused; e.g. if L{CooperativeTask.pause} was not invoked
more times than L{CooperativeTask.resume} on that object.
"""
self.assertRaises(task.NotPaused, self.task.resume)
self.task.pause()
self.task.resume()
self.assertRaises(task.NotPaused, self.task.resume)
def test_pauseTwice(self):
"""
Pauses on tasks should behave like a stack. If a task is paused twice,
it needs to be resumed twice.
"""
# pause once
self.task.pause()
self.scheduler.pump()
self.assertEqual(self.work, [])
# pause twice
self.task.pause()
self.scheduler.pump()
self.assertEqual(self.work, [])
# resume once (it shouldn't)
self.task.resume()
self.scheduler.pump()
self.assertEqual(self.work, [])
# resume twice (now it should go)
self.task.resume()
self.scheduler.pump()
self.assertEqual(self.work, [1])
def test_pauseWhileDeferred(self):
"""
C{pause()}ing a task while it is waiting on an outstanding
L{defer.Deferred} should put the task into a state where the
outstanding L{defer.Deferred} must be called back I{and} the task is
C{resume}d before it will continue processing.
"""
self.deferNext()
self.scheduler.pump()
self.assertEqual(len(self.work), 1)
self.failUnless(isinstance(self.work[0], defer.Deferred))
self.scheduler.pump()
self.assertEqual(len(self.work), 1)
self.task.pause()
self.scheduler.pump()
self.assertEqual(len(self.work), 1)
self.task.resume()
self.scheduler.pump()
self.assertEqual(len(self.work), 1)
self.work[0].callback("STUFF!")
self.scheduler.pump()
self.assertEqual(len(self.work), 2)
self.assertEqual(self.work[1], 2)
def test_whenDone(self):
"""
L{CooperativeTask.whenDone} returns a Deferred which fires when the
Cooperator's iterator is exhausted. It returns a new Deferred each
time it is called; callbacks added to other invocations will not modify
the value that subsequent invocations will fire with.
"""
deferred1 = self.task.whenDone()
deferred2 = self.task.whenDone()
results1 = []
results2 = []
final1 = []
final2 = []
def callbackOne(result):
results1.append(result)
return 1
def callbackTwo(result):
results2.append(result)
return 2
deferred1.addCallback(callbackOne)
deferred2.addCallback(callbackTwo)
deferred1.addCallback(final1.append)
deferred2.addCallback(final2.append)
# exhaust the task iterator
# callbacks fire
self.stopNext()
self.scheduler.pump()
self.assertEqual(len(results1), 1)
self.assertEqual(len(results2), 1)
self.assertIdentical(results1[0], self.task._iterator)
self.assertIdentical(results2[0], self.task._iterator)
self.assertEqual(final1, [1])
self.assertEqual(final2, [2])
def test_whenDoneError(self):
"""
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that will fail
when the iterable's C{next} method raises an exception, with that
exception.
"""
deferred1 = self.task.whenDone()
results = []
deferred1.addErrback(results.append)
self.dieNext()
self.scheduler.pump()
self.assertEqual(len(results), 1)
self.assertEqual(results[0].check(UnhandledException), UnhandledException)
def test_whenDoneStop(self):
"""
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that fails with
L{TaskStopped} when the C{stop} method is called on that
L{CooperativeTask}.
"""
deferred1 = self.task.whenDone()
errors = []
deferred1.addErrback(errors.append)
self.task.stop()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].check(task.TaskStopped), task.TaskStopped)
def test_whenDoneAlreadyDone(self):
"""
L{CooperativeTask.whenDone} will return a L{defer.Deferred} that will
succeed immediately if its iterator has already completed.
"""
self.stopNext()
self.scheduler.pump()
results = []
self.task.whenDone().addCallback(results.append)
self.assertEqual(results, [self.task._iterator])
def test_stopStops(self):
"""
C{stop()}ping a task should cause it to be removed from the run just as
C{pause()}ing, with the distinction that C{resume()} will raise a
L{TaskStopped} exception.
"""
self.task.stop()
self.scheduler.pump()
self.assertEqual(len(self.work), 0)
self.assertRaises(task.TaskStopped, self.task.stop)
self.assertRaises(task.TaskStopped, self.task.pause)
# Sanity check - it's still not scheduled, is it?
self.scheduler.pump()
self.assertEqual(self.work, [])
def test_pauseStopResume(self):
"""
C{resume()}ing a paused, stopped task should be a no-op; it should not
raise an exception, because it's paused, but neither should it actually
do more work from the task.
"""
self.task.pause()
self.task.stop()
self.task.resume()
self.scheduler.pump()
self.assertEqual(self.work, [])
def test_stopDeferred(self):
"""
As a corrolary of the interaction of C{pause()} and C{unpause()},
C{stop()}ping a task which is waiting on a L{Deferred} should cause the
task to gracefully shut down, meaning that it should not be unpaused
when the deferred fires.
"""
self.deferNext()
self.scheduler.pump()
d = self.work.pop()
self.assertEqual(self.task._pauseCount, 1)
results = []
d.addBoth(results.append)
self.scheduler.pump()
self.task.stop()
self.scheduler.pump()
d.callback(7)
self.scheduler.pump()
# Let's make sure that Deferred doesn't come out fried with an
# unhandled error that will be logged. The value is None, rather than
# our test value, 7, because this Deferred is returned to and consumed
# by the cooperator code. Its callback therefore has no contract.
self.assertEqual(results, [None])
# But more importantly, no further work should have happened.
self.assertEqual(self.work, [])
def test_stopExhausted(self):
"""
C{stop()}ping a L{CooperativeTask} whose iterator has been exhausted
should raise L{TaskDone}.
"""
self.stopNext()
self.scheduler.pump()
self.assertRaises(task.TaskDone, self.task.stop)
def test_stopErrored(self):
"""
C{stop()}ping a L{CooperativeTask} whose iterator has encountered an
error should raise L{TaskFailed}.
"""
self.dieNext()
self.scheduler.pump()
self.assertRaises(task.TaskFailed, self.task.stop)
def test_stopCooperatorReentrancy(self):
"""
If a callback of a L{Deferred} from L{CooperativeTask.whenDone} calls
C{Cooperator.stop} on its L{CooperativeTask._cooperator}, the
L{Cooperator} will stop, but the L{CooperativeTask} whose callback is
calling C{stop} should already be considered 'stopped' by the time the
callback is running, and therefore removed from the
L{CoooperativeTask}.
"""
callbackPhases = []
def stopit(result):
callbackPhases.append(result)
self.cooperator.stop()
# "done" here is a sanity check to make sure that we get all the
# way through the callback; i.e. stop() shouldn't be raising an
# exception due to the stopped-ness of our main task.
callbackPhases.append("done")
self.task.whenDone().addCallback(stopit)
self.stopNext()
self.scheduler.pump()
self.assertEqual(callbackPhases, [self.task._iterator, "done"])

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,301 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.defer.deferredGenerator} and related APIs.
"""
from __future__ import division, absolute_import
import sys
from twisted.internet import reactor
from twisted.trial import unittest
from twisted.internet.defer import waitForDeferred, deferredGenerator, Deferred
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet import defer
def getThing():
d = Deferred()
reactor.callLater(0, d.callback, "hi")
return d
def getOwie():
d = Deferred()
def CRAP():
d.errback(ZeroDivisionError('OMG'))
reactor.callLater(0, CRAP)
return d
# NOTE: most of the tests in DeferredGeneratorTests are duplicated
# with slightly different syntax for the InlineCallbacksTests below.
class TerminalException(Exception):
pass
class BaseDefgenTests:
"""
This class sets up a bunch of test cases which will test both
deferredGenerator and inlineCallbacks based generators. The subclasses
DeferredGeneratorTests and InlineCallbacksTests each provide the actual
generator implementations tested.
"""
def testBasics(self):
"""
Test that a normal deferredGenerator works. Tests yielding a
deferred which callbacks, as well as a deferred errbacks. Also
ensures returning a final value works.
"""
return self._genBasics().addCallback(self.assertEqual, 'WOOSH')
def testBuggy(self):
"""
Ensure that a buggy generator properly signals a Failure
condition on result deferred.
"""
return self.assertFailure(self._genBuggy(), ZeroDivisionError)
def testNothing(self):
"""Test that a generator which never yields results in None."""
return self._genNothing().addCallback(self.assertEqual, None)
def testHandledTerminalFailure(self):
"""
Create a Deferred Generator which yields a Deferred which fails and
handles the exception which results. Assert that the Deferred
Generator does not errback its Deferred.
"""
return self._genHandledTerminalFailure().addCallback(self.assertEqual, None)
def testHandledTerminalAsyncFailure(self):
"""
Just like testHandledTerminalFailure, only with a Deferred which fires
asynchronously with an error.
"""
d = defer.Deferred()
deferredGeneratorResultDeferred = self._genHandledTerminalAsyncFailure(d)
d.errback(TerminalException("Handled Terminal Failure"))
return deferredGeneratorResultDeferred.addCallback(
self.assertEqual, None)
def testStackUsage(self):
"""
Make sure we don't blow the stack when yielding immediately
available deferreds.
"""
return self._genStackUsage().addCallback(self.assertEqual, 0)
def testStackUsage2(self):
"""
Make sure we don't blow the stack when yielding immediately
available values.
"""
return self._genStackUsage2().addCallback(self.assertEqual, 0)
class DeferredGeneratorTests(BaseDefgenTests, unittest.TestCase):
# First provide all the generator impls necessary for BaseDefgenTests
def _genBasics(self):
x = waitForDeferred(getThing())
yield x
x = x.getResult()
self.assertEqual(x, "hi")
ow = waitForDeferred(getOwie())
yield ow
try:
ow.getResult()
except ZeroDivisionError as e:
self.assertEqual(str(e), 'OMG')
yield "WOOSH"
return
_genBasics = deferredGenerator(_genBasics)
def _genBuggy(self):
yield waitForDeferred(getThing())
1//0
_genBuggy = deferredGenerator(_genBuggy)
def _genNothing(self):
if 0: yield 1
_genNothing = deferredGenerator(_genNothing)
def _genHandledTerminalFailure(self):
x = waitForDeferred(defer.fail(TerminalException("Handled Terminal Failure")))
yield x
try:
x.getResult()
except TerminalException:
pass
_genHandledTerminalFailure = deferredGenerator(_genHandledTerminalFailure)
def _genHandledTerminalAsyncFailure(self, d):
x = waitForDeferred(d)
yield x
try:
x.getResult()
except TerminalException:
pass
_genHandledTerminalAsyncFailure = deferredGenerator(_genHandledTerminalAsyncFailure)
def _genStackUsage(self):
for x in range(5000):
# Test with yielding a deferred
x = waitForDeferred(defer.succeed(1))
yield x
x = x.getResult()
yield 0
_genStackUsage = deferredGenerator(_genStackUsage)
def _genStackUsage2(self):
for x in range(5000):
# Test with yielding a random value
yield 1
yield 0
_genStackUsage2 = deferredGenerator(_genStackUsage2)
# Tests unique to deferredGenerator
def testDeferredYielding(self):
"""
Ensure that yielding a Deferred directly is trapped as an
error.
"""
# See the comment _deferGenerator about d.callback(Deferred).
def _genDeferred():
yield getThing()
_genDeferred = deferredGenerator(_genDeferred)
return self.assertFailure(_genDeferred(), TypeError)
class InlineCallbacksTests(BaseDefgenTests, unittest.TestCase):
# First provide all the generator impls necessary for BaseDefgenTests
def _genBasics(self):
x = yield getThing()
self.assertEqual(x, "hi")
try:
ow = yield getOwie()
except ZeroDivisionError as e:
self.assertEqual(str(e), 'OMG')
returnValue("WOOSH")
_genBasics = inlineCallbacks(_genBasics)
def _genBuggy(self):
yield getThing()
1/0
_genBuggy = inlineCallbacks(_genBuggy)
def _genNothing(self):
if 0: yield 1
_genNothing = inlineCallbacks(_genNothing)
def _genHandledTerminalFailure(self):
try:
x = yield defer.fail(TerminalException("Handled Terminal Failure"))
except TerminalException:
pass
_genHandledTerminalFailure = inlineCallbacks(_genHandledTerminalFailure)
def _genHandledTerminalAsyncFailure(self, d):
try:
x = yield d
except TerminalException:
pass
_genHandledTerminalAsyncFailure = inlineCallbacks(
_genHandledTerminalAsyncFailure)
def _genStackUsage(self):
for x in range(5000):
# Test with yielding a deferred
x = yield defer.succeed(1)
returnValue(0)
_genStackUsage = inlineCallbacks(_genStackUsage)
def _genStackUsage2(self):
for x in range(5000):
# Test with yielding a random value
yield 1
returnValue(0)
_genStackUsage2 = inlineCallbacks(_genStackUsage2)
# Tests unique to inlineCallbacks
def testYieldNonDeferrred(self):
"""
Ensure that yielding a non-deferred passes it back as the
result of the yield expression.
"""
def _test():
x = yield 5
returnValue(5)
_test = inlineCallbacks(_test)
return _test().addCallback(self.assertEqual, 5)
def testReturnNoValue(self):
"""Ensure a standard python return results in a None result."""
def _noReturn():
yield 5
return
_noReturn = inlineCallbacks(_noReturn)
return _noReturn().addCallback(self.assertEqual, None)
def testReturnValue(self):
"""Ensure that returnValue works."""
def _return():
yield 5
returnValue(6)
_return = inlineCallbacks(_return)
return _return().addCallback(self.assertEqual, 6)
def test_nonGeneratorReturn(self):
"""
Ensure that C{TypeError} with a message about L{inlineCallbacks} is
raised when a non-generator returns something other than a generator.
"""
def _noYield():
return 5
_noYield = inlineCallbacks(_noYield)
self.assertIn("inlineCallbacks",
str(self.assertRaises(TypeError, _noYield)))
def test_nonGeneratorReturnValue(self):
"""
Ensure that C{TypeError} with a message about L{inlineCallbacks} is
raised when a non-generator calls L{returnValue}.
"""
def _noYield():
returnValue(5)
_noYield = inlineCallbacks(_noYield)
self.assertIn("inlineCallbacks",
str(self.assertRaises(TypeError, _noYield)))

View file

@ -0,0 +1,22 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.protocols import dict
paramString = "\"This is a dqstring \\w\\i\\t\\h boring stuff like: \\\"\" and t\\hes\\\"e are a\\to\\ms"
goodparams = ["This is a dqstring with boring stuff like: \"", "and", "thes\"e", "are", "atoms"]
class ParamTest(unittest.TestCase):
def testParseParam(self):
"""Testing command response handling"""
params = []
rest = paramString
while 1:
(param, rest) = dict.parseParam(rest)
if param == None:
break
params.append(param)
self.assertEqual(params, goodparams)#, "DictClient.parseParam returns unexpected results")

View file

@ -0,0 +1,681 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.cred._digest} and the associated bits in
L{twisted.cred.credentials}.
"""
from hashlib import md5, sha1
from zope.interface.verify import verifyObject
from twisted.trial.unittest import TestCase
from twisted.internet.address import IPv4Address
from twisted.cred.error import LoginFailed
from twisted.cred.credentials import calcHA1, calcHA2, IUsernameDigestHash
from twisted.cred.credentials import calcResponse, DigestCredentialFactory
def b64encode(s):
return s.encode('base64').strip()
class FakeDigestCredentialFactory(DigestCredentialFactory):
"""
A Fake Digest Credential Factory that generates a predictable
nonce and opaque
"""
def __init__(self, *args, **kwargs):
super(FakeDigestCredentialFactory, self).__init__(*args, **kwargs)
self.privateKey = "0"
def _generateNonce(self):
"""
Generate a static nonce
"""
return '178288758716122392881254770685'
def _getTime(self):
"""
Return a stable time
"""
return 0
class DigestAuthTests(TestCase):
"""
L{TestCase} mixin class which defines a number of tests for
L{DigestCredentialFactory}. Because this mixin defines C{setUp}, it
must be inherited before L{TestCase}.
"""
def setUp(self):
"""
Create a DigestCredentialFactory for testing
"""
self.username = "foobar"
self.password = "bazquux"
self.realm = "test realm"
self.algorithm = "md5"
self.cnonce = "29fc54aa1641c6fa0e151419361c8f23"
self.qop = "auth"
self.uri = "/write/"
self.clientAddress = IPv4Address('TCP', '10.2.3.4', 43125)
self.method = 'GET'
self.credentialFactory = DigestCredentialFactory(
self.algorithm, self.realm)
def test_MD5HashA1(self, _algorithm='md5', _hash=md5):
"""
L{calcHA1} accepts the C{'md5'} algorithm and returns an MD5 hash of
its parameters, excluding the nonce and cnonce.
"""
nonce = 'abc123xyz'
hashA1 = calcHA1(_algorithm, self.username, self.realm, self.password,
nonce, self.cnonce)
a1 = '%s:%s:%s' % (self.username, self.realm, self.password)
expected = _hash(a1).hexdigest()
self.assertEqual(hashA1, expected)
def test_MD5SessionHashA1(self):
"""
L{calcHA1} accepts the C{'md5-sess'} algorithm and returns an MD5 hash
of its parameters, including the nonce and cnonce.
"""
nonce = 'xyz321abc'
hashA1 = calcHA1('md5-sess', self.username, self.realm, self.password,
nonce, self.cnonce)
a1 = '%s:%s:%s' % (self.username, self.realm, self.password)
ha1 = md5(a1).digest()
a1 = '%s:%s:%s' % (ha1, nonce, self.cnonce)
expected = md5(a1).hexdigest()
self.assertEqual(hashA1, expected)
def test_SHAHashA1(self):
"""
L{calcHA1} accepts the C{'sha'} algorithm and returns a SHA hash of its
parameters, excluding the nonce and cnonce.
"""
self.test_MD5HashA1('sha', sha1)
def test_MD5HashA2Auth(self, _algorithm='md5', _hash=md5):
"""
L{calcHA2} accepts the C{'md5'} algorithm and returns an MD5 hash of
its arguments, excluding the entity hash for QOP other than
C{'auth-int'}.
"""
method = 'GET'
hashA2 = calcHA2(_algorithm, method, self.uri, 'auth', None)
a2 = '%s:%s' % (method, self.uri)
expected = _hash(a2).hexdigest()
self.assertEqual(hashA2, expected)
def test_MD5HashA2AuthInt(self, _algorithm='md5', _hash=md5):
"""
L{calcHA2} accepts the C{'md5'} algorithm and returns an MD5 hash of
its arguments, including the entity hash for QOP of C{'auth-int'}.
"""
method = 'GET'
hentity = 'foobarbaz'
hashA2 = calcHA2(_algorithm, method, self.uri, 'auth-int', hentity)
a2 = '%s:%s:%s' % (method, self.uri, hentity)
expected = _hash(a2).hexdigest()
self.assertEqual(hashA2, expected)
def test_MD5SessHashA2Auth(self):
"""
L{calcHA2} accepts the C{'md5-sess'} algorithm and QOP of C{'auth'} and
returns the same value as it does for the C{'md5'} algorithm.
"""
self.test_MD5HashA2Auth('md5-sess')
def test_MD5SessHashA2AuthInt(self):
"""
L{calcHA2} accepts the C{'md5-sess'} algorithm and QOP of C{'auth-int'}
and returns the same value as it does for the C{'md5'} algorithm.
"""
self.test_MD5HashA2AuthInt('md5-sess')
def test_SHAHashA2Auth(self):
"""
L{calcHA2} accepts the C{'sha'} algorithm and returns a SHA hash of
its arguments, excluding the entity hash for QOP other than
C{'auth-int'}.
"""
self.test_MD5HashA2Auth('sha', sha1)
def test_SHAHashA2AuthInt(self):
"""
L{calcHA2} accepts the C{'sha'} algorithm and returns a SHA hash of
its arguments, including the entity hash for QOP of C{'auth-int'}.
"""
self.test_MD5HashA2AuthInt('sha', sha1)
def test_MD5HashResponse(self, _algorithm='md5', _hash=md5):
"""
L{calcResponse} accepts the C{'md5'} algorithm and returns an MD5 hash
of its parameters, excluding the nonce count, client nonce, and QoP
value if the nonce count and client nonce are C{None}
"""
hashA1 = 'abc123'
hashA2 = '789xyz'
nonce = 'lmnopq'
response = '%s:%s:%s' % (hashA1, nonce, hashA2)
expected = _hash(response).hexdigest()
digest = calcResponse(hashA1, hashA2, _algorithm, nonce, None, None,
None)
self.assertEqual(expected, digest)
def test_MD5SessionHashResponse(self):
"""
L{calcResponse} accepts the C{'md5-sess'} algorithm and returns an MD5
hash of its parameters, excluding the nonce count, client nonce, and
QoP value if the nonce count and client nonce are C{None}
"""
self.test_MD5HashResponse('md5-sess')
def test_SHAHashResponse(self):
"""
L{calcResponse} accepts the C{'sha'} algorithm and returns a SHA hash
of its parameters, excluding the nonce count, client nonce, and QoP
value if the nonce count and client nonce are C{None}
"""
self.test_MD5HashResponse('sha', sha1)
def test_MD5HashResponseExtra(self, _algorithm='md5', _hash=md5):
"""
L{calcResponse} accepts the C{'md5'} algorithm and returns an MD5 hash
of its parameters, including the nonce count, client nonce, and QoP
value if they are specified.
"""
hashA1 = 'abc123'
hashA2 = '789xyz'
nonce = 'lmnopq'
nonceCount = '00000004'
clientNonce = 'abcxyz123'
qop = 'auth'
response = '%s:%s:%s:%s:%s:%s' % (
hashA1, nonce, nonceCount, clientNonce, qop, hashA2)
expected = _hash(response).hexdigest()
digest = calcResponse(
hashA1, hashA2, _algorithm, nonce, nonceCount, clientNonce, qop)
self.assertEqual(expected, digest)
def test_MD5SessionHashResponseExtra(self):
"""
L{calcResponse} accepts the C{'md5-sess'} algorithm and returns an MD5
hash of its parameters, including the nonce count, client nonce, and
QoP value if they are specified.
"""
self.test_MD5HashResponseExtra('md5-sess')
def test_SHAHashResponseExtra(self):
"""
L{calcResponse} accepts the C{'sha'} algorithm and returns a SHA hash
of its parameters, including the nonce count, client nonce, and QoP
value if they are specified.
"""
self.test_MD5HashResponseExtra('sha', sha1)
def formatResponse(self, quotes=True, **kw):
"""
Format all given keyword arguments and their values suitably for use as
the value of an HTTP header.
@types quotes: C{bool}
@param quotes: A flag indicating whether to quote the values of each
field in the response.
@param **kw: Keywords and C{str} values which will be treated as field
name/value pairs to include in the result.
@rtype: C{str}
@return: The given fields formatted for use as an HTTP header value.
"""
if 'username' not in kw:
kw['username'] = self.username
if 'realm' not in kw:
kw['realm'] = self.realm
if 'algorithm' not in kw:
kw['algorithm'] = self.algorithm
if 'qop' not in kw:
kw['qop'] = self.qop
if 'cnonce' not in kw:
kw['cnonce'] = self.cnonce
if 'uri' not in kw:
kw['uri'] = self.uri
if quotes:
quote = '"'
else:
quote = ''
return ', '.join([
'%s=%s%s%s' % (k, quote, v, quote)
for (k, v)
in kw.iteritems()
if v is not None])
def getDigestResponse(self, challenge, ncount):
"""
Calculate the response for the given challenge
"""
nonce = challenge.get('nonce')
algo = challenge.get('algorithm').lower()
qop = challenge.get('qop')
ha1 = calcHA1(
algo, self.username, self.realm, self.password, nonce, self.cnonce)
ha2 = calcHA2(algo, "GET", self.uri, qop, None)
expected = calcResponse(ha1, ha2, algo, nonce, ncount, self.cnonce, qop)
return expected
def test_response(self, quotes=True):
"""
L{DigestCredentialFactory.decode} accepts a digest challenge response
and parses it into an L{IUsernameHashedPassword} provider.
"""
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
nc = "00000001"
clientResponse = self.formatResponse(
quotes=quotes,
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(
clientResponse, self.method, self.clientAddress.host)
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
def test_responseWithoutQuotes(self):
"""
L{DigestCredentialFactory.decode} accepts a digest challenge response
which does not quote the values of its fields and parses it into an
L{IUsernameHashedPassword} provider in the same way it would a
response which included quoted field values.
"""
self.test_response(False)
def test_responseWithCommaURI(self):
"""
L{DigestCredentialFactory.decode} accepts a digest challenge response
which quotes the values of its fields and includes a C{b","} in the URI
field.
"""
self.uri = b"/some,path/"
self.test_response(True)
def test_caseInsensitiveAlgorithm(self):
"""
The case of the algorithm value in the response is ignored when
checking the credentials.
"""
self.algorithm = 'MD5'
self.test_response()
def test_md5DefaultAlgorithm(self):
"""
The algorithm defaults to MD5 if it is not supplied in the response.
"""
self.algorithm = None
self.test_response()
def test_responseWithoutClientIP(self):
"""
L{DigestCredentialFactory.decode} accepts a digest challenge response
even if the client address it is passed is C{None}.
"""
challenge = self.credentialFactory.getChallenge(None)
nc = "00000001"
clientResponse = self.formatResponse(
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(clientResponse, self.method, None)
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
def test_multiResponse(self):
"""
L{DigestCredentialFactory.decode} handles multiple responses to a
single challenge.
"""
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
nc = "00000001"
clientResponse = self.formatResponse(
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(clientResponse, self.method,
self.clientAddress.host)
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
nc = "00000002"
clientResponse = self.formatResponse(
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(clientResponse, self.method,
self.clientAddress.host)
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
def test_failsWithDifferentMethod(self):
"""
L{DigestCredentialFactory.decode} returns an L{IUsernameHashedPassword}
provider which rejects a correct password for the given user if the
challenge response request is made using a different HTTP method than
was used to request the initial challenge.
"""
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
nc = "00000001"
clientResponse = self.formatResponse(
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(clientResponse, 'POST',
self.clientAddress.host)
self.assertFalse(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
def test_noUsername(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
has no username field or if the username field is empty.
"""
# Check for no username
e = self.assertRaises(
LoginFailed,
self.credentialFactory.decode,
self.formatResponse(username=None),
self.method, self.clientAddress.host)
self.assertEqual(str(e), "Invalid response, no username given.")
# Check for an empty username
e = self.assertRaises(
LoginFailed,
self.credentialFactory.decode,
self.formatResponse(username=""),
self.method, self.clientAddress.host)
self.assertEqual(str(e), "Invalid response, no username given.")
def test_noNonce(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
has no nonce.
"""
e = self.assertRaises(
LoginFailed,
self.credentialFactory.decode,
self.formatResponse(opaque="abc123"),
self.method, self.clientAddress.host)
self.assertEqual(str(e), "Invalid response, no nonce given.")
def test_noOpaque(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
has no opaque.
"""
e = self.assertRaises(
LoginFailed,
self.credentialFactory.decode,
self.formatResponse(),
self.method, self.clientAddress.host)
self.assertEqual(str(e), "Invalid response, no opaque given.")
def test_checkHash(self):
"""
L{DigestCredentialFactory.decode} returns an L{IUsernameDigestHash}
provider which can verify a hash of the form 'username:realm:password'.
"""
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
nc = "00000001"
clientResponse = self.formatResponse(
nonce=challenge['nonce'],
response=self.getDigestResponse(challenge, nc),
nc=nc,
opaque=challenge['opaque'])
creds = self.credentialFactory.decode(clientResponse, self.method,
self.clientAddress.host)
self.assertTrue(verifyObject(IUsernameDigestHash, creds))
cleartext = '%s:%s:%s' % (self.username, self.realm, self.password)
hash = md5(cleartext)
self.assertTrue(creds.checkHash(hash.hexdigest()))
hash.update('wrong')
self.assertFalse(creds.checkHash(hash.hexdigest()))
def test_invalidOpaque(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the opaque
value does not contain all the required parts.
"""
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
self.realm)
challenge = credentialFactory.getChallenge(self.clientAddress.host)
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
'badOpaque',
challenge['nonce'],
self.clientAddress.host)
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
badOpaque = 'foo-' + b64encode('nonce,clientip')
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badOpaque,
challenge['nonce'],
self.clientAddress.host)
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
'',
challenge['nonce'],
self.clientAddress.host)
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
badOpaque = (
'foo-' + b64encode('%s,%s,foobar' % (
challenge['nonce'],
self.clientAddress.host)))
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badOpaque,
challenge['nonce'],
self.clientAddress.host)
self.assertEqual(
str(exc), 'Invalid response, invalid opaque/time values')
def test_incompatibleNonce(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the given
nonce from the response does not match the nonce encoded in the opaque.
"""
credentialFactory = FakeDigestCredentialFactory(self.algorithm, self.realm)
challenge = credentialFactory.getChallenge(self.clientAddress.host)
badNonceOpaque = credentialFactory._generateOpaque(
'1234567890',
self.clientAddress.host)
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badNonceOpaque,
challenge['nonce'],
self.clientAddress.host)
self.assertEqual(
str(exc),
'Invalid response, incompatible opaque/nonce values')
exc = self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badNonceOpaque,
'',
self.clientAddress.host)
self.assertEqual(
str(exc),
'Invalid response, incompatible opaque/nonce values')
def test_incompatibleClientIP(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the
request comes from a client IP other than what is encoded in the
opaque.
"""
credentialFactory = FakeDigestCredentialFactory(self.algorithm, self.realm)
challenge = credentialFactory.getChallenge(self.clientAddress.host)
badAddress = '10.0.0.1'
# Sanity check
self.assertNotEqual(self.clientAddress.host, badAddress)
badNonceOpaque = credentialFactory._generateOpaque(
challenge['nonce'], badAddress)
self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badNonceOpaque,
challenge['nonce'],
self.clientAddress.host)
def test_oldNonce(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the given
opaque is older than C{DigestCredentialFactory.CHALLENGE_LIFETIME_SECS}
"""
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
self.realm)
challenge = credentialFactory.getChallenge(self.clientAddress.host)
key = '%s,%s,%s' % (challenge['nonce'],
self.clientAddress.host,
'-137876876')
digest = md5(key + credentialFactory.privateKey).hexdigest()
ekey = b64encode(key)
oldNonceOpaque = '%s-%s' % (digest, ekey.strip('\n'))
self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
oldNonceOpaque,
challenge['nonce'],
self.clientAddress.host)
def test_mismatchedOpaqueChecksum(self):
"""
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the opaque
checksum fails verification.
"""
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
self.realm)
challenge = credentialFactory.getChallenge(self.clientAddress.host)
key = '%s,%s,%s' % (challenge['nonce'],
self.clientAddress.host,
'0')
digest = md5(key + 'this is not the right pkey').hexdigest()
badChecksum = '%s-%s' % (digest, b64encode(key))
self.assertRaises(
LoginFailed,
credentialFactory._verifyOpaque,
badChecksum,
challenge['nonce'],
self.clientAddress.host)
def test_incompatibleCalcHA1Options(self):
"""
L{calcHA1} raises L{TypeError} when any of the pszUsername, pszRealm,
or pszPassword arguments are specified with the preHA1 keyword
argument.
"""
arguments = (
("user", "realm", "password", "preHA1"),
(None, "realm", None, "preHA1"),
(None, None, "password", "preHA1"),
)
for pszUsername, pszRealm, pszPassword, preHA1 in arguments:
self.assertRaises(
TypeError,
calcHA1,
"md5",
pszUsername,
pszRealm,
pszPassword,
"nonce",
"cnonce",
preHA1=preHA1)
def test_noNewlineOpaque(self):
"""
L{DigestCredentialFactory._generateOpaque} returns a value without
newlines, regardless of the length of the nonce.
"""
opaque = self.credentialFactory._generateOpaque(
"long nonce " * 10, None)
self.assertNotIn('\n', opaque)

View file

@ -0,0 +1,170 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for dirdbm module.
"""
import os, shutil, glob
from twisted.trial import unittest
from twisted.persisted import dirdbm
class DirDbmTestCase(unittest.TestCase):
def setUp(self):
self.path = self.mktemp()
self.dbm = dirdbm.open(self.path)
self.items = (('abc', 'foo'), ('/lalal', '\000\001'), ('\000\012', 'baz'))
def testAll(self):
k = "//==".decode("base64")
self.dbm[k] = "a"
self.dbm[k] = "a"
self.assertEqual(self.dbm[k], "a")
def testRebuildInteraction(self):
from twisted.persisted import dirdbm
from twisted.python import rebuild
s = dirdbm.Shelf('dirdbm.rebuild.test')
s['key'] = 'value'
rebuild.rebuild(dirdbm)
# print s['key']
def testDbm(self):
d = self.dbm
# insert keys
keys = []
values = set()
for k, v in self.items:
d[k] = v
keys.append(k)
values.add(v)
keys.sort()
# check they exist
for k, v in self.items:
assert d.has_key(k), "has_key() failed"
assert d[k] == v, "database has wrong value"
# check non existent key
try:
d["XXX"]
except KeyError:
pass
else:
assert 0, "didn't raise KeyError on non-existent key"
# check keys(), values() and items()
dbkeys = list(d.keys())
dbvalues = set(d.values())
dbitems = set(d.items())
dbkeys.sort()
items = set(self.items)
assert keys == dbkeys, ".keys() output didn't match: %s != %s" % (repr(keys), repr(dbkeys))
assert values == dbvalues, ".values() output didn't match: %s != %s" % (repr(values), repr(dbvalues))
assert items == dbitems, "items() didn't match: %s != %s" % (repr(items), repr(dbitems))
copyPath = self.mktemp()
d2 = d.copyTo(copyPath)
copykeys = list(d.keys())
copyvalues = set(d.values())
copyitems = set(d.items())
copykeys.sort()
assert dbkeys == copykeys, ".copyTo().keys() didn't match: %s != %s" % (repr(dbkeys), repr(copykeys))
assert dbvalues == copyvalues, ".copyTo().values() didn't match: %s != %s" % (repr(dbvalues), repr(copyvalues))
assert dbitems == copyitems, ".copyTo().items() didn't match: %s != %s" % (repr(dbkeys), repr(copyitems))
d2.clear()
assert len(d2.keys()) == len(d2.values()) == len(d2.items()) == 0, ".clear() failed"
shutil.rmtree(copyPath)
# delete items
for k, v in self.items:
del d[k]
assert not d.has_key(k), "has_key() even though we deleted it"
assert len(d.keys()) == 0, "database has keys"
assert len(d.values()) == 0, "database has values"
assert len(d.items()) == 0, "database has items"
def testModificationTime(self):
import time
# the mtime value for files comes from a different place than the
# gettimeofday() system call. On linux, gettimeofday() can be
# slightly ahead (due to clock drift which gettimeofday() takes into
# account but which open()/write()/close() do not), and if we are
# close to the edge of the next second, time.time() can give a value
# which is larger than the mtime which results from a subsequent
# write(). I consider this a kernel bug, but it is beyond the scope
# of this test. Thus we keep the range of acceptability to 3 seconds time.
# -warner
self.dbm["k"] = "v"
self.assert_(abs(time.time() - self.dbm.getModificationTime("k")) <= 3)
def testRecovery(self):
"""DirDBM: test recovery from directory after a faked crash"""
k = self.dbm._encode("key1")
f = open(os.path.join(self.path, k + ".rpl"), "wb")
f.write("value")
f.close()
k2 = self.dbm._encode("key2")
f = open(os.path.join(self.path, k2), "wb")
f.write("correct")
f.close()
f = open(os.path.join(self.path, k2 + ".rpl"), "wb")
f.write("wrong")
f.close()
f = open(os.path.join(self.path, "aa.new"), "wb")
f.write("deleted")
f.close()
dbm = dirdbm.DirDBM(self.path)
assert dbm["key1"] == "value"
assert dbm["key2"] == "correct"
assert not glob.glob(os.path.join(self.path, "*.new"))
assert not glob.glob(os.path.join(self.path, "*.rpl"))
def test_nonStringKeys(self):
"""
L{dirdbm.DirDBM} operations only support string keys: other types
should raise a C{AssertionError}. This really ought to be a
C{TypeError}, but it'll stay like this for backward compatibility.
"""
self.assertRaises(AssertionError, self.dbm.__setitem__, 2, "3")
try:
self.assertRaises(AssertionError, self.dbm.__setitem__, "2", 3)
except unittest.FailTest:
# dirdbm.Shelf.__setitem__ supports non-string values
self.assertIsInstance(self.dbm, dirdbm.Shelf)
self.assertRaises(AssertionError, self.dbm.__getitem__, 2)
self.assertRaises(AssertionError, self.dbm.__delitem__, 2)
self.assertRaises(AssertionError, self.dbm.has_key, 2)
self.assertRaises(AssertionError, self.dbm.__contains__, 2)
self.assertRaises(AssertionError, self.dbm.getModificationTime, 2)
class ShelfTestCase(DirDbmTestCase):
def setUp(self):
self.path = self.mktemp()
self.dbm = dirdbm.Shelf(self.path)
self.items = (('abc', 'foo'), ('/lalal', '\000\001'), ('\000\012', 'baz'),
('int', 12), ('float', 12.0), ('tuple', (None, 12)))
testCases = [DirDbmTestCase, ShelfTestCase]

View file

@ -0,0 +1,104 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import inspect, glob
from os import path
from twisted.trial import unittest
from twisted.python import reflect
from twisted.python.modules import getModule
def errorInFile(f, line=17, name=''):
"""
Return a filename formatted so emacs will recognize it as an error point
@param line: Line number in file. Defaults to 17 because that's about how
long the copyright headers are.
"""
return '%s:%d:%s' % (f, line, name)
# return 'File "%s", line %d, in %s' % (f, line, name)
class DocCoverage(unittest.TestCase):
"""
Looking for docstrings in all modules and packages.
"""
def setUp(self):
self.packageNames = []
for mod in getModule('twisted').walkModules():
if mod.isPackage():
self.packageNames.append(mod.name)
def testModules(self):
"""
Looking for docstrings in all modules.
"""
docless = []
for packageName in self.packageNames:
if packageName in ('twisted.test',):
# because some stuff in here behaves oddly when imported
continue
try:
package = reflect.namedModule(packageName)
except ImportError, e:
# This is testing doc coverage, not importability.
# (Really, I don't want to deal with the fact that I don't
# have pyserial installed.)
# print e
pass
else:
docless.extend(self.modulesInPackage(packageName, package))
self.failIf(docless, "No docstrings in module files:\n"
"%s" % ('\n'.join(map(errorInFile, docless)),))
def modulesInPackage(self, packageName, package):
docless = []
directory = path.dirname(package.__file__)
for modfile in glob.glob(path.join(directory, '*.py')):
moduleName = inspect.getmodulename(modfile)
if moduleName == '__init__':
# These are tested by test_packages.
continue
elif moduleName in ('spelunk_gnome','gtkmanhole'):
# argh special case pygtk evil argh. How does epydoc deal
# with this?
continue
try:
module = reflect.namedModule('.'.join([packageName,
moduleName]))
except Exception, e:
# print moduleName, "misbehaved:", e
pass
else:
if not inspect.getdoc(module):
docless.append(modfile)
return docless
def testPackages(self):
"""
Looking for docstrings in all packages.
"""
docless = []
for packageName in self.packageNames:
try:
package = reflect.namedModule(packageName)
except Exception, e:
# This is testing doc coverage, not importability.
# (Really, I don't want to deal with the fact that I don't
# have pyserial installed.)
# print e
pass
else:
if not inspect.getdoc(package):
docless.append(package.__file__.replace('.pyc','.py'))
self.failIf(docless, "No docstrings for package files\n"
"%s" % ('\n'.join(map(errorInFile, docless),)))
# This test takes a while and doesn't come close to passing. :(
testModules.skip = "Activate me when you feel like writing docstrings, and fixing GTK crashing bugs."

View file

@ -0,0 +1,266 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import division, absolute_import
import socket, errno
from twisted.trial import unittest
from twisted.internet import error
from twisted.python.runtime import platformType
class TestStringification(unittest.SynchronousTestCase):
"""Test that the exceptions have useful stringifications.
"""
listOfTests = [
#(output, exception[, args[, kwargs]]),
("An error occurred binding to an interface.",
error.BindError),
("An error occurred binding to an interface: foo.",
error.BindError, ['foo']),
("An error occurred binding to an interface: foo bar.",
error.BindError, ['foo', 'bar']),
("Couldn't listen on eth0:4242: Foo.",
error.CannotListenError,
('eth0', 4242, socket.error('Foo'))),
("Message is too long to send.",
error.MessageLengthError),
("Message is too long to send: foo bar.",
error.MessageLengthError, ['foo', 'bar']),
("DNS lookup failed.",
error.DNSLookupError),
("DNS lookup failed: foo bar.",
error.DNSLookupError, ['foo', 'bar']),
("An error occurred while connecting.",
error.ConnectError),
("An error occurred while connecting: someOsError.",
error.ConnectError, ['someOsError']),
("An error occurred while connecting: foo.",
error.ConnectError, [], {'string': 'foo'}),
("An error occurred while connecting: someOsError: foo.",
error.ConnectError, ['someOsError', 'foo']),
("Couldn't bind.",
error.ConnectBindError),
("Couldn't bind: someOsError.",
error.ConnectBindError, ['someOsError']),
("Couldn't bind: someOsError: foo.",
error.ConnectBindError, ['someOsError', 'foo']),
("Hostname couldn't be looked up.",
error.UnknownHostError),
("No route to host.",
error.NoRouteError),
("Connection was refused by other side.",
error.ConnectionRefusedError),
("TCP connection timed out.",
error.TCPTimedOutError),
("File used for UNIX socket is no good.",
error.BadFileError),
("Service name given as port is unknown.",
error.ServiceNameUnknownError),
("User aborted connection.",
error.UserError),
("User timeout caused connection failure.",
error.TimeoutError),
("An SSL error occurred.",
error.SSLError),
("Connection to the other side was lost in a non-clean fashion.",
error.ConnectionLost),
("Connection to the other side was lost in a non-clean fashion: foo bar.",
error.ConnectionLost, ['foo', 'bar']),
("Connection was closed cleanly.",
error.ConnectionDone),
("Connection was closed cleanly: foo bar.",
error.ConnectionDone, ['foo', 'bar']),
("Uh.", #TODO nice docstring, you've got there.
error.ConnectionFdescWentAway),
("Tried to cancel an already-called event.",
error.AlreadyCalled),
("Tried to cancel an already-called event: foo bar.",
error.AlreadyCalled, ['foo', 'bar']),
("Tried to cancel an already-cancelled event.",
error.AlreadyCancelled),
("Tried to cancel an already-cancelled event: x 2.",
error.AlreadyCancelled, ["x", "2"]),
("A process has ended without apparent errors: process finished with exit code 0.",
error.ProcessDone,
[None]),
("A process has ended with a probable error condition: process ended.",
error.ProcessTerminated),
("A process has ended with a probable error condition: process ended with exit code 42.",
error.ProcessTerminated,
[],
{'exitCode': 42}),
("A process has ended with a probable error condition: process ended by signal SIGBUS.",
error.ProcessTerminated,
[],
{'signal': 'SIGBUS'}),
("The Connector was not connecting when it was asked to stop connecting.",
error.NotConnectingError),
("The Connector was not connecting when it was asked to stop connecting: x 13.",
error.NotConnectingError, ["x", "13"]),
("The Port was not listening when it was asked to stop listening.",
error.NotListeningError),
("The Port was not listening when it was asked to stop listening: a 12.",
error.NotListeningError, ["a", "12"]),
]
def testThemAll(self):
for entry in self.listOfTests:
output = entry[0]
exception = entry[1]
try:
args = entry[2]
except IndexError:
args = ()
try:
kwargs = entry[3]
except IndexError:
kwargs = {}
self.assertEqual(
str(exception(*args, **kwargs)),
output)
def test_connectingCancelledError(self):
"""
L{error.ConnectingCancelledError} has an C{address} attribute.
"""
address = object()
e = error.ConnectingCancelledError(address)
self.assertIdentical(e.address, address)
class SubclassingTests(unittest.SynchronousTestCase):
"""
Some exceptions are subclasses of other exceptions.
"""
def test_connectionLostSubclassOfConnectionClosed(self):
"""
L{error.ConnectionClosed} is a superclass of L{error.ConnectionLost}.
"""
self.assertTrue(issubclass(error.ConnectionLost,
error.ConnectionClosed))
def test_connectionDoneSubclassOfConnectionClosed(self):
"""
L{error.ConnectionClosed} is a superclass of L{error.ConnectionDone}.
"""
self.assertTrue(issubclass(error.ConnectionDone,
error.ConnectionClosed))
def test_invalidAddressErrorSubclassOfValueError(self):
"""
L{ValueError} is a superclass of L{error.InvalidAddressError}.
"""
self.assertTrue(issubclass(error.InvalidAddressError,
ValueError))
class GetConnectErrorTests(unittest.SynchronousTestCase):
"""
Given an exception instance thrown by C{socket.connect},
L{error.getConnectError} returns the appropriate high-level Twisted
exception instance.
"""
def assertErrnoException(self, errno, expectedClass):
"""
When called with a tuple with the given errno,
L{error.getConnectError} returns an exception which is an instance of
the expected class.
"""
e = (errno, "lalala")
result = error.getConnectError(e)
self.assertCorrectException(errno, "lalala", result, expectedClass)
def assertCorrectException(self, errno, message, result, expectedClass):
"""
The given result of L{error.getConnectError} has the given attributes
(C{osError} and C{args}), and is an instance of the given class.
"""
# Want exact class match, not inherited classes, so no isinstance():
self.assertEqual(result.__class__, expectedClass)
self.assertEqual(result.osError, errno)
self.assertEqual(result.args, (message,))
def test_errno(self):
"""
L{error.getConnectError} converts based on errno for C{socket.error}.
"""
self.assertErrnoException(errno.ENETUNREACH, error.NoRouteError)
self.assertErrnoException(errno.ECONNREFUSED, error.ConnectionRefusedError)
self.assertErrnoException(errno.ETIMEDOUT, error.TCPTimedOutError)
if platformType == "win32":
self.assertErrnoException(errno.WSAECONNREFUSED, error.ConnectionRefusedError)
self.assertErrnoException(errno.WSAENETUNREACH, error.NoRouteError)
def test_gaierror(self):
"""
L{error.getConnectError} converts to a L{error.UnknownHostError} given
a C{socket.gaierror} instance.
"""
result = error.getConnectError(socket.gaierror(12, "hello"))
self.assertCorrectException(12, "hello", result, error.UnknownHostError)
def test_nonTuple(self):
"""
L{error.getConnectError} converts to a L{error.ConnectError} given
an argument that cannot be unpacked.
"""
e = Exception()
result = error.getConnectError(e)
self.assertCorrectException(None, e, result, error.ConnectError)

View file

@ -0,0 +1,236 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for explorer
"""
from twisted.trial import unittest
from twisted.manhole import explorer
import types
"""
# Tests:
Get an ObjectLink. Browse ObjectLink.identifier. Is it the same?
Watch Object. Make sure an ObjectLink is received when:
Call a method.
Set an attribute.
Have an Object with a setattr class. Watch it.
Do both the navite setattr and the watcher get called?
Sequences with circular references. Does it blow up?
"""
class SomeDohickey:
def __init__(self, *a):
self.__dict__['args'] = a
def bip(self):
return self.args
class TestBrowser(unittest.TestCase):
def setUp(self):
self.pool = explorer.explorerPool
self.pool.clear()
self.testThing = ["How many stairs must a man climb down?",
SomeDohickey(42)]
def test_chain(self):
"Following a chain of Explorers."
xplorer = self.pool.getExplorer(self.testThing, 'testThing')
self.assertEqual(xplorer.id, id(self.testThing))
self.assertEqual(xplorer.identifier, 'testThing')
dxplorer = xplorer.get_elements()[1]
self.assertEqual(dxplorer.id, id(self.testThing[1]))
class Watcher:
zero = 0
def __init__(self):
self.links = []
def receiveBrowserObject(self, olink):
self.links.append(olink)
def setZero(self):
self.zero = len(self.links)
def len(self):
return len(self.links) - self.zero
class SetattrDohickey:
def __setattr__(self, k, v):
v = list(str(v))
v.reverse()
self.__dict__[k] = ''.join(v)
class MiddleMan(SomeDohickey, SetattrDohickey):
pass
# class TestWatch(unittest.TestCase):
class FIXME_Watch:
def setUp(self):
self.globalNS = globals().copy()
self.localNS = {}
self.browser = explorer.ObjectBrowser(self.globalNS, self.localNS)
self.watcher = Watcher()
def test_setAttrPlain(self):
"Triggering a watcher response by setting an attribute."
testThing = SomeDohickey('pencil')
self.browser.watchObject(testThing, 'testThing',
self.watcher.receiveBrowserObject)
self.watcher.setZero()
testThing.someAttr = 'someValue'
self.assertEqual(testThing.someAttr, 'someValue')
self.failUnless(self.watcher.len())
olink = self.watcher.links[-1]
self.assertEqual(olink.id, id(testThing))
def test_setAttrChain(self):
"Setting an attribute on a watched object that has __setattr__"
testThing = MiddleMan('pencil')
self.browser.watchObject(testThing, 'testThing',
self.watcher.receiveBrowserObject)
self.watcher.setZero()
testThing.someAttr = 'ZORT'
self.assertEqual(testThing.someAttr, 'TROZ')
self.failUnless(self.watcher.len())
olink = self.watcher.links[-1]
self.assertEqual(olink.id, id(testThing))
def test_method(self):
"Triggering a watcher response by invoking a method."
for testThing in (SomeDohickey('pencil'), MiddleMan('pencil')):
self.browser.watchObject(testThing, 'testThing',
self.watcher.receiveBrowserObject)
self.watcher.setZero()
rval = testThing.bip()
self.assertEqual(rval, ('pencil',))
self.failUnless(self.watcher.len())
olink = self.watcher.links[-1]
self.assertEqual(olink.id, id(testThing))
def function_noArgs():
"A function which accepts no arguments at all."
return
def function_simple(a, b, c):
"A function which accepts several arguments."
return a, b, c
def function_variable(*a, **kw):
"A function which accepts a variable number of args and keywords."
return a, kw
def function_crazy((alpha, beta), c, d=range(4), **kw):
"A function with a mad crazy signature."
return alpha, beta, c, d, kw
class TestBrowseFunction(unittest.TestCase):
def setUp(self):
self.pool = explorer.explorerPool
self.pool.clear()
def test_sanity(self):
"""Basic checks for browse_function.
Was the proper type returned? Does it have the right name and ID?
"""
for f_name in ('function_noArgs', 'function_simple',
'function_variable', 'function_crazy'):
f = eval(f_name)
xplorer = self.pool.getExplorer(f, f_name)
self.assertEqual(xplorer.id, id(f))
self.failUnless(isinstance(xplorer, explorer.ExplorerFunction))
self.assertEqual(xplorer.name, f_name)
def test_signature_noArgs(self):
"""Testing zero-argument function signature.
"""
xplorer = self.pool.getExplorer(function_noArgs, 'function_noArgs')
self.assertEqual(len(xplorer.signature), 0)
def test_signature_simple(self):
"""Testing simple function signature.
"""
xplorer = self.pool.getExplorer(function_simple, 'function_simple')
expected_signature = ('a','b','c')
self.assertEqual(xplorer.signature.name, expected_signature)
def test_signature_variable(self):
"""Testing variable-argument function signature.
"""
xplorer = self.pool.getExplorer(function_variable,
'function_variable')
expected_names = ('a','kw')
signature = xplorer.signature
self.assertEqual(signature.name, expected_names)
self.failUnless(signature.is_varlist(0))
self.failUnless(signature.is_keyword(1))
def test_signature_crazy(self):
"""Testing function with crazy signature.
"""
xplorer = self.pool.getExplorer(function_crazy, 'function_crazy')
signature = xplorer.signature
expected_signature = [{'name': 'c'},
{'name': 'd',
'default': range(4)},
{'name': 'kw',
'keywords': 1}]
# The name of the first argument seems to be indecipherable,
# but make sure it has one (and no default).
self.failUnless(signature.get_name(0))
self.failUnless(not signature.get_default(0)[0])
self.assertEqual(signature.get_name(1), 'c')
# Get a list of values from a list of ExplorerImmutables.
arg_2_default = map(lambda l: l.value,
signature.get_default(2)[1].get_elements())
self.assertEqual(signature.get_name(2), 'd')
self.assertEqual(arg_2_default, range(4))
self.assertEqual(signature.get_name(3), 'kw')
self.failUnless(signature.is_keyword(3))
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,145 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test code for basic Factory classes.
"""
from __future__ import division, absolute_import
import pickle
from twisted.trial.unittest import TestCase
from twisted.internet.task import Clock
from twisted.internet.protocol import ReconnectingClientFactory, Protocol
class FakeConnector(object):
"""
A fake connector class, to be used to mock connections failed or lost.
"""
def stopConnecting(self):
pass
def connect(self):
pass
class ReconnectingFactoryTestCase(TestCase):
"""
Tests for L{ReconnectingClientFactory}.
"""
def test_stopTryingWhenConnected(self):
"""
If a L{ReconnectingClientFactory} has C{stopTrying} called while it is
connected, it does not subsequently attempt to reconnect if the
connection is later lost.
"""
class NoConnectConnector(object):
def stopConnecting(self):
raise RuntimeError("Shouldn't be called, we're connected.")
def connect(self):
raise RuntimeError("Shouldn't be reconnecting.")
c = ReconnectingClientFactory()
c.protocol = Protocol
# Let's pretend we've connected:
c.buildProtocol(None)
# Now we stop trying, then disconnect:
c.stopTrying()
c.clientConnectionLost(NoConnectConnector(), None)
self.assertFalse(c.continueTrying)
def test_stopTryingDoesNotReconnect(self):
"""
Calling stopTrying on a L{ReconnectingClientFactory} doesn't attempt a
retry on any active connector.
"""
class FactoryAwareFakeConnector(FakeConnector):
attemptedRetry = False
def stopConnecting(self):
"""
Behave as though an ongoing connection attempt has now
failed, and notify the factory of this.
"""
f.clientConnectionFailed(self, None)
def connect(self):
"""
Record an attempt to reconnect, since this is what we
are trying to avoid.
"""
self.attemptedRetry = True
f = ReconnectingClientFactory()
f.clock = Clock()
# simulate an active connection - stopConnecting on this connector should
# be triggered when we call stopTrying
f.connector = FactoryAwareFakeConnector()
f.stopTrying()
# make sure we never attempted to retry
self.assertFalse(f.connector.attemptedRetry)
self.assertFalse(f.clock.getDelayedCalls())
def test_serializeUnused(self):
"""
A L{ReconnectingClientFactory} which hasn't been used for anything
can be pickled and unpickled and end up with the same state.
"""
original = ReconnectingClientFactory()
reconstituted = pickle.loads(pickle.dumps(original))
self.assertEqual(original.__dict__, reconstituted.__dict__)
def test_serializeWithClock(self):
"""
The clock attribute of L{ReconnectingClientFactory} is not serialized,
and the restored value sets it to the default value, the reactor.
"""
clock = Clock()
original = ReconnectingClientFactory()
original.clock = clock
reconstituted = pickle.loads(pickle.dumps(original))
self.assertIdentical(reconstituted.clock, None)
def test_deserializationResetsParameters(self):
"""
A L{ReconnectingClientFactory} which is unpickled does not have an
L{IConnector} and has its reconnecting timing parameters reset to their
initial values.
"""
factory = ReconnectingClientFactory()
factory.clientConnectionFailed(FakeConnector(), None)
self.addCleanup(factory.stopTrying)
serialized = pickle.dumps(factory)
unserialized = pickle.loads(serialized)
self.assertEqual(unserialized.connector, None)
self.assertEqual(unserialized._callID, None)
self.assertEqual(unserialized.retries, 0)
self.assertEqual(unserialized.delay, factory.initialDelay)
self.assertEqual(unserialized.continueTrying, True)
def test_parametrizedClock(self):
"""
The clock used by L{ReconnectingClientFactory} can be parametrized, so
that one can cleanly test reconnections.
"""
clock = Clock()
factory = ReconnectingClientFactory()
factory.clock = clock
factory.clientConnectionLost(FakeConnector(), None)
self.assertEqual(len(clock.calls), 1)

View file

@ -0,0 +1,993 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for the L{twisted.python.failure} module.
"""
from __future__ import division, absolute_import
import re
import sys
import traceback
import pdb
import linecache
from twisted.python.compat import NativeStringIO, _PY3
from twisted.python import reflect
from twisted.python import failure
from twisted.trial.unittest import SynchronousTestCase
try:
from twisted.test import raiser
except ImportError:
raiser = None
def getDivisionFailure(*args, **kwargs):
"""
Make a C{Failure} of a divide-by-zero error.
@param args: Any C{*args} are passed to Failure's constructor.
@param kwargs: Any C{**kwargs} are passed to Failure's constructor.
"""
try:
1/0
except:
f = failure.Failure(*args, **kwargs)
return f
class FailureTestCase(SynchronousTestCase):
"""
Tests for L{failure.Failure}.
"""
def test_failAndTrap(self):
"""
Trapping a L{Failure}.
"""
try:
raise NotImplementedError('test')
except:
f = failure.Failure()
error = f.trap(SystemExit, RuntimeError)
self.assertEqual(error, RuntimeError)
self.assertEqual(f.type, NotImplementedError)
def test_trapRaisesCurrentFailure(self):
"""
If the wrapped C{Exception} is not a subclass of one of the
expected types, L{failure.Failure.trap} raises the current
L{failure.Failure} ie C{self}.
"""
exception = ValueError()
try:
raise exception
except:
f = failure.Failure()
untrapped = self.assertRaises(failure.Failure, f.trap, OverflowError)
self.assertIdentical(f, untrapped)
if _PY3:
test_trapRaisesCurrentFailure.skip = (
"In Python3, Failure.trap raises the wrapped Exception "
"instead of the original Failure instance.")
def test_trapRaisesWrappedException(self):
"""
If the wrapped C{Exception} is not a subclass of one of the
expected types, L{failure.Failure.trap} raises the wrapped
C{Exception}.
"""
exception = ValueError()
try:
raise exception
except:
f = failure.Failure()
untrapped = self.assertRaises(ValueError, f.trap, OverflowError)
self.assertIdentical(exception, untrapped)
if not _PY3:
test_trapRaisesWrappedException.skip = (
"In Python2, Failure.trap raises the current Failure instance "
"instead of the wrapped Exception.")
def test_failureValueFromFailure(self):
"""
A L{failure.Failure} constructed from another
L{failure.Failure} instance, has its C{value} property set to
the value of that L{failure.Failure} instance.
"""
exception = ValueError()
f1 = failure.Failure(exception)
f2 = failure.Failure(f1)
self.assertIdentical(f2.value, exception)
def test_failureValueFromFoundFailure(self):
"""
A L{failure.Failure} constructed without a C{exc_value}
argument, will search for an "original" C{Failure}, and if
found, its value will be used as the value for the new
C{Failure}.
"""
exception = ValueError()
f1 = failure.Failure(exception)
try:
f1.trap(OverflowError)
except:
f2 = failure.Failure()
self.assertIdentical(f2.value, exception)
def assertStartsWith(self, s, prefix):
"""
Assert that C{s} starts with a particular C{prefix}.
@param s: The input string.
@type s: C{str}
@param prefix: The string that C{s} should start with.
@type prefix: C{str}
"""
self.assertTrue(s.startswith(prefix),
'%r is not the start of %r' % (prefix, s))
def assertEndsWith(self, s, suffix):
"""
Assert that C{s} end with a particular C{suffix}.
@param s: The input string.
@type s: C{str}
@param suffix: The string that C{s} should end with.
@type suffix: C{str}
"""
self.assertTrue(s.endswith(suffix),
'%r is not the end of %r' % (suffix, s))
def assertTracebackFormat(self, tb, prefix, suffix):
"""
Assert that the C{tb} traceback contains a particular C{prefix} and
C{suffix}.
@param tb: The traceback string.
@type tb: C{str}
@param prefix: The string that C{tb} should start with.
@type prefix: C{str}
@param suffix: The string that C{tb} should end with.
@type suffix: C{str}
"""
self.assertStartsWith(tb, prefix)
self.assertEndsWith(tb, suffix)
def assertDetailedTraceback(self, captureVars=False, cleanFailure=False):
"""
Assert that L{printDetailedTraceback} produces and prints a detailed
traceback.
The detailed traceback consists of a header::
*--- Failure #20 ---
The body contains the stacktrace::
/twisted/trial/_synctest.py:1180: _run(...)
/twisted/python/util.py:1076: runWithWarningsSuppressed(...)
--- <exception caught here> ---
/twisted/test/test_failure.py:39: getDivisionFailure(...)
If C{captureVars} is enabled the body also includes a list of
globals and locals::
[ Locals ]
exampleLocalVar : 'xyz'
...
( Globals )
...
Or when C{captureVars} is disabled::
[Capture of Locals and Globals disabled (use captureVars=True)]
When C{cleanFailure} is enabled references to other objects are removed
and replaced with strings.
And finally the footer with the L{Failure}'s value::
exceptions.ZeroDivisionError: float division
*--- End of Failure #20 ---
@param captureVars: Enables L{Failure.captureVars}.
@type captureVars: C{bool}
@param cleanFailure: Enables L{Failure.cleanFailure}.
@type cleanFailure: C{bool}
"""
if captureVars:
exampleLocalVar = 'xyz'
f = getDivisionFailure(captureVars=captureVars)
out = NativeStringIO()
if cleanFailure:
f.cleanFailure()
f.printDetailedTraceback(out)
tb = out.getvalue()
start = "*--- Failure #%d%s---\n" % (f.count,
(f.pickled and ' (pickled) ') or ' ')
end = "%s: %s\n*--- End of Failure #%s ---\n" % (reflect.qual(f.type),
reflect.safe_str(f.value), f.count)
self.assertTracebackFormat(tb, start, end)
# Variables are printed on lines with 2 leading spaces.
linesWithVars = [line for line in tb.splitlines()
if line.startswith(' ')]
if captureVars:
self.assertNotEqual([], linesWithVars)
if cleanFailure:
line = ' exampleLocalVar : "\'xyz\'"'
else:
line = " exampleLocalVar : 'xyz'"
self.assertIn(line, linesWithVars)
else:
self.assertEqual([], linesWithVars)
self.assertIn(' [Capture of Locals and Globals disabled (use '
'captureVars=True)]\n', tb)
def assertBriefTraceback(self, captureVars=False):
"""
Assert that L{printBriefTraceback} produces and prints a brief
traceback.
The brief traceback consists of a header::
Traceback: <type 'exceptions.ZeroDivisionError'>: float division
The body with the stacktrace::
/twisted/trial/_synctest.py:1180:_run
/twisted/python/util.py:1076:runWithWarningsSuppressed
And the footer::
--- <exception caught here> ---
/twisted/test/test_failure.py:39:getDivisionFailure
@param captureVars: Enables L{Failure.captureVars}.
@type captureVars: C{bool}
"""
if captureVars:
exampleLocalVar = 'abcde'
f = getDivisionFailure()
out = NativeStringIO()
f.printBriefTraceback(out)
tb = out.getvalue()
stack = ''
for method, filename, lineno, localVars, globalVars in f.frames:
stack += '%s:%s:%s\n' % (filename, lineno, method)
if _PY3:
zde = "class 'ZeroDivisionError'"
else:
zde = "type 'exceptions.ZeroDivisionError'"
self.assertTracebackFormat(tb,
"Traceback: <%s>: " % (zde,),
"%s\n%s" % (failure.EXCEPTION_CAUGHT_HERE, stack))
if captureVars:
self.assertEqual(None, re.search('exampleLocalVar.*abcde', tb))
def assertDefaultTraceback(self, captureVars=False):
"""
Assert that L{printTraceback} produces and prints a default traceback.
The default traceback consists of a header::
Traceback (most recent call last):
The body with traceback::
File "/twisted/trial/_synctest.py", line 1180, in _run
runWithWarningsSuppressed(suppress, method)
And the footer::
--- <exception caught here> ---
File "twisted/test/test_failure.py", line 39, in getDivisionFailure
1/0
exceptions.ZeroDivisionError: float division
@param captureVars: Enables L{Failure.captureVars}.
@type captureVars: C{bool}
"""
if captureVars:
exampleLocalVar = 'xyzzy'
f = getDivisionFailure(captureVars=captureVars)
out = NativeStringIO()
f.printTraceback(out)
tb = out.getvalue()
stack = ''
for method, filename, lineno, localVars, globalVars in f.frames:
stack += ' File "%s", line %s, in %s\n' % (filename, lineno,
method)
stack += ' %s\n' % (linecache.getline(
filename, lineno).strip(),)
self.assertTracebackFormat(tb,
"Traceback (most recent call last):",
"%s\n%s%s: %s\n" % (failure.EXCEPTION_CAUGHT_HERE, stack,
reflect.qual(f.type), reflect.safe_str(f.value)))
if captureVars:
self.assertEqual(None, re.search('exampleLocalVar.*xyzzy', tb))
def test_printDetailedTraceback(self):
"""
L{printDetailedTraceback} returns a detailed traceback including the
L{Failure}'s count.
"""
self.assertDetailedTraceback()
def test_printBriefTraceback(self):
"""
L{printBriefTraceback} returns a brief traceback.
"""
self.assertBriefTraceback()
def test_printTraceback(self):
"""
L{printTraceback} returns a traceback.
"""
self.assertDefaultTraceback()
def test_printDetailedTracebackCapturedVars(self):
"""
L{printDetailedTraceback} captures the locals and globals for its
stack frames and adds them to the traceback, when called on a
L{Failure} constructed with C{captureVars=True}.
"""
self.assertDetailedTraceback(captureVars=True)
def test_printBriefTracebackCapturedVars(self):
"""
L{printBriefTraceback} returns a brief traceback when called on a
L{Failure} constructed with C{captureVars=True}.
Local variables on the stack can not be seen in the resulting
traceback.
"""
self.assertBriefTraceback(captureVars=True)
def test_printTracebackCapturedVars(self):
"""
L{printTraceback} returns a traceback when called on a L{Failure}
constructed with C{captureVars=True}.
Local variables on the stack can not be seen in the resulting
traceback.
"""
self.assertDefaultTraceback(captureVars=True)
def test_printDetailedTracebackCapturedVarsCleaned(self):
"""
C{printDetailedTraceback} includes information about local variables on
the stack after C{cleanFailure} has been called.
"""
self.assertDetailedTraceback(captureVars=True, cleanFailure=True)
def test_invalidFormatFramesDetail(self):
"""
L{failure.format_frames} raises a L{ValueError} if the supplied
C{detail} level is unknown.
"""
self.assertRaises(ValueError, failure.format_frames, None, None,
detail='noisia')
def testExplictPass(self):
e = RuntimeError()
f = failure.Failure(e)
f.trap(RuntimeError)
self.assertEqual(f.value, e)
def _getInnermostFrameLine(self, f):
try:
f.raiseException()
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
return tb[-1][-1]
else:
raise Exception(
"f.raiseException() didn't raise ZeroDivisionError!?")
def testRaiseExceptionWithTB(self):
f = getDivisionFailure()
innerline = self._getInnermostFrameLine(f)
self.assertEqual(innerline, '1/0')
def testLackOfTB(self):
f = getDivisionFailure()
f.cleanFailure()
innerline = self._getInnermostFrameLine(f)
self.assertEqual(innerline, '1/0')
testLackOfTB.todo = "the traceback is not preserved, exarkun said he'll try to fix this! god knows how"
if _PY3:
del testLackOfTB # fix in ticket #6008
def test_stringExceptionConstruction(self):
"""
Constructing a C{Failure} with a string as its exception value raises
a C{TypeError}, as this is no longer supported as of Python 2.6.
"""
exc = self.assertRaises(TypeError, failure.Failure, "ono!")
self.assertIn("Strings are not supported by Failure", str(exc))
def testConstructionFails(self):
"""
Creating a Failure with no arguments causes it to try to discover the
current interpreter exception state. If no such state exists, creating
the Failure should raise a synchronous exception.
"""
self.assertRaises(failure.NoCurrentExceptionError, failure.Failure)
def test_getTracebackObject(self):
"""
If the C{Failure} has not been cleaned, then C{getTracebackObject}
returns the traceback object that captured in its constructor.
"""
f = getDivisionFailure()
self.assertEqual(f.getTracebackObject(), f.tb)
def test_getTracebackObjectFromCaptureVars(self):
"""
C{captureVars=True} has no effect on the result of
C{getTracebackObject}.
"""
try:
1/0
except ZeroDivisionError:
noVarsFailure = failure.Failure()
varsFailure = failure.Failure(captureVars=True)
self.assertEqual(noVarsFailure.getTracebackObject(), varsFailure.tb)
def test_getTracebackObjectFromClean(self):
"""
If the Failure has been cleaned, then C{getTracebackObject} returns an
object that looks the same to L{traceback.extract_tb}.
"""
f = getDivisionFailure()
expected = traceback.extract_tb(f.getTracebackObject())
f.cleanFailure()
observed = traceback.extract_tb(f.getTracebackObject())
self.assertNotEqual(None, expected)
self.assertEqual(expected, observed)
def test_getTracebackObjectFromCaptureVarsAndClean(self):
"""
If the Failure was created with captureVars, then C{getTracebackObject}
returns an object that looks the same to L{traceback.extract_tb}.
"""
f = getDivisionFailure(captureVars=True)
expected = traceback.extract_tb(f.getTracebackObject())
f.cleanFailure()
observed = traceback.extract_tb(f.getTracebackObject())
self.assertEqual(expected, observed)
def test_getTracebackObjectWithoutTraceback(self):
"""
L{failure.Failure}s need not be constructed with traceback objects. If
a C{Failure} has no traceback information at all, C{getTracebackObject}
just returns None.
None is a good value, because traceback.extract_tb(None) -> [].
"""
f = failure.Failure(Exception("some error"))
self.assertEqual(f.getTracebackObject(), None)
def test_tracebackFromExceptionInPython3(self):
"""
If a L{failure.Failure} is constructed with an exception but no
traceback in Python 3, the traceback will be extracted from the
exception's C{__traceback__} attribute.
"""
try:
1/0
except:
klass, exception, tb = sys.exc_info()
f = failure.Failure(exception)
self.assertIdentical(f.tb, tb)
def test_cleanFailureRemovesTracebackInPython3(self):
"""
L{failure.Failure.cleanFailure} sets the C{__traceback__} attribute of
the exception to C{None} in Python 3.
"""
f = getDivisionFailure()
self.assertNotEqual(f.tb, None)
self.assertIdentical(f.value.__traceback__, f.tb)
f.cleanFailure()
self.assertIdentical(f.value.__traceback__, None)
if not _PY3:
test_tracebackFromExceptionInPython3.skip = "Python 3 only."
test_cleanFailureRemovesTracebackInPython3.skip = "Python 3 only."
class BrokenStr(Exception):
"""
An exception class the instances of which cannot be presented as strings via
C{str}.
"""
def __str__(self):
# Could raise something else, but there's no point as yet.
raise self
class BrokenExceptionMetaclass(type):
"""
A metaclass for an exception type which cannot be presented as a string via
C{str}.
"""
def __str__(self):
raise ValueError("You cannot make a string out of me.")
class BrokenExceptionType(Exception, object):
"""
The aforementioned exception type which cnanot be presented as a string via
C{str}.
"""
__metaclass__ = BrokenExceptionMetaclass
class GetTracebackTests(SynchronousTestCase):
"""
Tests for L{Failure.getTraceback}.
"""
def _brokenValueTest(self, detail):
"""
Construct a L{Failure} with an exception that raises an exception from
its C{__str__} method and then call C{getTraceback} with the specified
detail and verify that it returns a string.
"""
x = BrokenStr()
f = failure.Failure(x)
traceback = f.getTraceback(detail=detail)
self.assertIsInstance(traceback, str)
def test_brokenValueBriefDetail(self):
"""
A L{Failure} might wrap an exception with a C{__str__} method which
raises an exception. In this case, calling C{getTraceback} on the
failure with the C{"brief"} detail does not raise an exception.
"""
self._brokenValueTest("brief")
def test_brokenValueDefaultDetail(self):
"""
Like test_brokenValueBriefDetail, but for the C{"default"} detail case.
"""
self._brokenValueTest("default")
def test_brokenValueVerboseDetail(self):
"""
Like test_brokenValueBriefDetail, but for the C{"default"} detail case.
"""
self._brokenValueTest("verbose")
def _brokenTypeTest(self, detail):
"""
Construct a L{Failure} with an exception type that raises an exception
from its C{__str__} method and then call C{getTraceback} with the
specified detail and verify that it returns a string.
"""
f = failure.Failure(BrokenExceptionType())
traceback = f.getTraceback(detail=detail)
self.assertIsInstance(traceback, str)
def test_brokenTypeBriefDetail(self):
"""
A L{Failure} might wrap an exception the type object of which has a
C{__str__} method which raises an exception. In this case, calling
C{getTraceback} on the failure with the C{"brief"} detail does not raise
an exception.
"""
self._brokenTypeTest("brief")
def test_brokenTypeDefaultDetail(self):
"""
Like test_brokenTypeBriefDetail, but for the C{"default"} detail case.
"""
self._brokenTypeTest("default")
def test_brokenTypeVerboseDetail(self):
"""
Like test_brokenTypeBriefDetail, but for the C{"verbose"} detail case.
"""
self._brokenTypeTest("verbose")
class FindFailureTests(SynchronousTestCase):
"""
Tests for functionality related to L{Failure._findFailure}.
"""
def test_findNoFailureInExceptionHandler(self):
"""
Within an exception handler, _findFailure should return
C{None} in case no Failure is associated with the current
exception.
"""
try:
1/0
except:
self.assertEqual(failure.Failure._findFailure(), None)
else:
self.fail("No exception raised from 1/0!?")
def test_findNoFailure(self):
"""
Outside of an exception handler, _findFailure should return None.
"""
self.assertEqual(sys.exc_info()[-1], None) #environment sanity check
self.assertEqual(failure.Failure._findFailure(), None)
def test_findFailure(self):
"""
Within an exception handler, it should be possible to find the
original Failure that caused the current exception (if it was
caused by raiseException).
"""
f = getDivisionFailure()
f.cleanFailure()
try:
f.raiseException()
except:
self.assertEqual(failure.Failure._findFailure(), f)
else:
self.fail("No exception raised from raiseException!?")
def test_failureConstructionFindsOriginalFailure(self):
"""
When a Failure is constructed in the context of an exception
handler that is handling an exception raised by
raiseException, the new Failure should be chained to that
original Failure.
"""
f = getDivisionFailure()
f.cleanFailure()
try:
f.raiseException()
except:
newF = failure.Failure()
self.assertEqual(f.getTraceback(), newF.getTraceback())
else:
self.fail("No exception raised from raiseException!?")
def test_failureConstructionWithMungedStackSucceeds(self):
"""
Pyrex and Cython are known to insert fake stack frames so as to give
more Python-like tracebacks. These stack frames with empty code objects
should not break extraction of the exception.
"""
try:
raiser.raiseException()
except raiser.RaiserException:
f = failure.Failure()
self.assertTrue(f.check(raiser.RaiserException))
else:
self.fail("No exception raised from extension?!")
if raiser is None:
skipMsg = "raiser extension not available"
test_failureConstructionWithMungedStackSucceeds.skip = skipMsg
class TestFormattableTraceback(SynchronousTestCase):
"""
Whitebox tests that show that L{failure._Traceback} constructs objects that
can be used by L{traceback.extract_tb}.
If the objects can be used by L{traceback.extract_tb}, then they can be
formatted using L{traceback.format_tb} and friends.
"""
def test_singleFrame(self):
"""
A C{_Traceback} object constructed with a single frame should be able
to be passed to L{traceback.extract_tb}, and we should get a singleton
list containing a (filename, lineno, methodname, line) tuple.
"""
tb = failure._Traceback([['method', 'filename.py', 123, {}, {}]])
# Note that we don't need to test that extract_tb correctly extracts
# the line's contents. In this case, since filename.py doesn't exist,
# it will just use None.
self.assertEqual(traceback.extract_tb(tb),
[('filename.py', 123, 'method', None)])
def test_manyFrames(self):
"""
A C{_Traceback} object constructed with multiple frames should be able
to be passed to L{traceback.extract_tb}, and we should get a list
containing a tuple for each frame.
"""
tb = failure._Traceback([
['method1', 'filename.py', 123, {}, {}],
['method2', 'filename.py', 235, {}, {}]])
self.assertEqual(traceback.extract_tb(tb),
[('filename.py', 123, 'method1', None),
('filename.py', 235, 'method2', None)])
class TestFrameAttributes(SynchronousTestCase):
"""
_Frame objects should possess some basic attributes that qualify them as
fake python Frame objects.
"""
def test_fakeFrameAttributes(self):
"""
L{_Frame} instances have the C{f_globals} and C{f_locals} attributes
bound to C{dict} instance. They also have the C{f_code} attribute
bound to something like a code object.
"""
frame = failure._Frame("dummyname", "dummyfilename")
self.assertIsInstance(frame.f_globals, dict)
self.assertIsInstance(frame.f_locals, dict)
self.assertIsInstance(frame.f_code, failure._Code)
class TestDebugMode(SynchronousTestCase):
"""
Failure's debug mode should allow jumping into the debugger.
"""
def setUp(self):
"""
Override pdb.post_mortem so we can make sure it's called.
"""
# Make sure any changes we make are reversed:
post_mortem = pdb.post_mortem
if _PY3:
origInit = failure.Failure.__init__
else:
origInit = failure.Failure.__dict__['__init__']
def restore():
pdb.post_mortem = post_mortem
if _PY3:
failure.Failure.__init__ = origInit
else:
failure.Failure.__dict__['__init__'] = origInit
self.addCleanup(restore)
self.result = []
pdb.post_mortem = self.result.append
failure.startDebugMode()
def test_regularFailure(self):
"""
If startDebugMode() is called, calling Failure() will first call
pdb.post_mortem with the traceback.
"""
try:
1/0
except:
typ, exc, tb = sys.exc_info()
f = failure.Failure()
self.assertEqual(self.result, [tb])
self.assertEqual(f.captureVars, False)
def test_captureVars(self):
"""
If startDebugMode() is called, passing captureVars to Failure() will
not blow up.
"""
try:
1/0
except:
typ, exc, tb = sys.exc_info()
f = failure.Failure(captureVars=True)
self.assertEqual(self.result, [tb])
self.assertEqual(f.captureVars, True)
class ExtendedGeneratorTests(SynchronousTestCase):
"""
Tests C{failure.Failure} support for generator features added in Python 2.5
"""
def _throwIntoGenerator(self, f, g):
try:
f.throwExceptionIntoGenerator(g)
except StopIteration:
pass
else:
self.fail("throwExceptionIntoGenerator should have raised "
"StopIteration")
def test_throwExceptionIntoGenerator(self):
"""
It should be possible to throw the exception that a Failure
represents into a generator.
"""
stuff = []
def generator():
try:
yield
except:
stuff.append(sys.exc_info())
else:
self.fail("Yield should have yielded exception.")
g = generator()
f = getDivisionFailure()
next(g)
self._throwIntoGenerator(f, g)
self.assertEqual(stuff[0][0], ZeroDivisionError)
self.assertTrue(isinstance(stuff[0][1], ZeroDivisionError))
self.assertEqual(traceback.extract_tb(stuff[0][2])[-1][-1], "1/0")
def test_findFailureInGenerator(self):
"""
Within an exception handler, it should be possible to find the
original Failure that caused the current exception (if it was
caused by throwExceptionIntoGenerator).
"""
f = getDivisionFailure()
f.cleanFailure()
foundFailures = []
def generator():
try:
yield
except:
foundFailures.append(failure.Failure._findFailure())
else:
self.fail("No exception sent to generator")
g = generator()
next(g)
self._throwIntoGenerator(f, g)
self.assertEqual(foundFailures, [f])
def test_failureConstructionFindsOriginalFailure(self):
"""
When a Failure is constructed in the context of an exception
handler that is handling an exception raised by
throwExceptionIntoGenerator, the new Failure should be chained to that
original Failure.
"""
f = getDivisionFailure()
f.cleanFailure()
newFailures = []
def generator():
try:
yield
except:
newFailures.append(failure.Failure())
else:
self.fail("No exception sent to generator")
g = generator()
next(g)
self._throwIntoGenerator(f, g)
self.assertEqual(len(newFailures), 1)
self.assertEqual(newFailures[0].getTraceback(), f.getTraceback())
if _PY3:
test_findFailureInGenerator.todo = (
"Python 3 support to be fixed in #5949")
test_failureConstructionFindsOriginalFailure.todo = (
"Python 3 support to be fixed in #5949")
# Remove these two lines in #6008 (unittest todo support):
del test_findFailureInGenerator
del test_failureConstructionFindsOriginalFailure
def test_ambiguousFailureInGenerator(self):
"""
When a generator reraises a different exception,
L{Failure._findFailure} inside the generator should find the reraised
exception rather than original one.
"""
def generator():
try:
try:
yield
except:
[][1]
except:
self.assertIsInstance(failure.Failure().value, IndexError)
g = generator()
next(g)
f = getDivisionFailure()
self._throwIntoGenerator(f, g)
def test_ambiguousFailureFromGenerator(self):
"""
When a generator reraises a different exception,
L{Failure._findFailure} above the generator should find the reraised
exception rather than original one.
"""
def generator():
try:
yield
except:
[][1]
g = generator()
next(g)
f = getDivisionFailure()
try:
self._throwIntoGenerator(f, g)
except:
self.assertIsInstance(failure.Failure().value, IndexError)

View file

@ -0,0 +1,266 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.fdesc}.
"""
import os, sys
import errno
try:
import fcntl
except ImportError:
skip = "not supported on this platform"
else:
from twisted.internet import fdesc
from twisted.python.util import untilConcludes
from twisted.trial import unittest
class NonBlockingTestCase(unittest.SynchronousTestCase):
"""
Tests for L{fdesc.setNonBlocking} and L{fdesc.setBlocking}.
"""
def test_setNonBlocking(self):
"""
L{fdesc.setNonBlocking} sets a file description to non-blocking.
"""
r, w = os.pipe()
self.addCleanup(os.close, r)
self.addCleanup(os.close, w)
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
fdesc.setNonBlocking(r)
self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
def test_setBlocking(self):
"""
L{fdesc.setBlocking} sets a file description to blocking.
"""
r, w = os.pipe()
self.addCleanup(os.close, r)
self.addCleanup(os.close, w)
fdesc.setNonBlocking(r)
fdesc.setBlocking(r)
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
class ReadWriteTestCase(unittest.SynchronousTestCase):
"""
Tests for L{fdesc.readFromFD}, L{fdesc.writeToFD}.
"""
def setUp(self):
"""
Create a non-blocking pipe that can be used in tests.
"""
self.r, self.w = os.pipe()
fdesc.setNonBlocking(self.r)
fdesc.setNonBlocking(self.w)
def tearDown(self):
"""
Close pipes.
"""
try:
os.close(self.w)
except OSError:
pass
try:
os.close(self.r)
except OSError:
pass
def write(self, d):
"""
Write data to the pipe.
"""
return fdesc.writeToFD(self.w, d)
def read(self):
"""
Read data from the pipe.
"""
l = []
res = fdesc.readFromFD(self.r, l.append)
if res is None:
if l:
return l[0]
else:
return b""
else:
return res
def test_writeAndRead(self):
"""
Test that the number of bytes L{fdesc.writeToFD} reports as written
with its return value are seen by L{fdesc.readFromFD}.
"""
n = self.write(b"hello")
self.failUnless(n > 0)
s = self.read()
self.assertEqual(len(s), n)
self.assertEqual(b"hello"[:n], s)
def test_writeAndReadLarge(self):
"""
Similar to L{test_writeAndRead}, but use a much larger string to verify
the behavior for that case.
"""
orig = b"0123456879" * 10000
written = self.write(orig)
self.failUnless(written > 0)
result = []
resultlength = 0
i = 0
while resultlength < written or i < 50:
result.append(self.read())
resultlength += len(result[-1])
# Increment a counter to be sure we'll exit at some point
i += 1
result = b"".join(result)
self.assertEqual(len(result), written)
self.assertEqual(orig[:written], result)
def test_readFromEmpty(self):
"""
Verify that reading from a file descriptor with no data does not raise
an exception and does not result in the callback function being called.
"""
l = []
result = fdesc.readFromFD(self.r, l.append)
self.assertEqual(l, [])
self.assertEqual(result, None)
def test_readFromCleanClose(self):
"""
Test that using L{fdesc.readFromFD} on a cleanly closed file descriptor
returns a connection done indicator.
"""
os.close(self.w)
self.assertEqual(self.read(), fdesc.CONNECTION_DONE)
def test_writeToClosed(self):
"""
Verify that writing with L{fdesc.writeToFD} when the read end is closed
results in a connection lost indicator.
"""
os.close(self.r)
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
def test_readFromInvalid(self):
"""
Verify that reading with L{fdesc.readFromFD} when the read end is
closed results in a connection lost indicator.
"""
os.close(self.r)
self.assertEqual(self.read(), fdesc.CONNECTION_LOST)
def test_writeToInvalid(self):
"""
Verify that writing with L{fdesc.writeToFD} when the write end is
closed results in a connection lost indicator.
"""
os.close(self.w)
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
def test_writeErrors(self):
"""
Test error path for L{fdesc.writeTod}.
"""
oldOsWrite = os.write
def eagainWrite(fd, data):
err = OSError()
err.errno = errno.EAGAIN
raise err
os.write = eagainWrite
try:
self.assertEqual(self.write(b"s"), 0)
finally:
os.write = oldOsWrite
def eintrWrite(fd, data):
err = OSError()
err.errno = errno.EINTR
raise err
os.write = eintrWrite
try:
self.assertEqual(self.write(b"s"), 0)
finally:
os.write = oldOsWrite
class CloseOnExecTests(unittest.SynchronousTestCase):
"""
Tests for L{fdesc._setCloseOnExec} and L{fdesc._unsetCloseOnExec}.
"""
program = '''
import os, errno
try:
os.write(%d, b'lul')
except OSError as e:
if e.errno == errno.EBADF:
os._exit(0)
os._exit(5)
except:
os._exit(10)
else:
os._exit(20)
'''
def _execWithFileDescriptor(self, fObj):
pid = os.fork()
if pid == 0:
try:
os.execv(sys.executable, [sys.executable, '-c', self.program % (fObj.fileno(),)])
except:
import traceback
traceback.print_exc()
os._exit(30)
else:
# On Linux wait(2) doesn't seem ever able to fail with EINTR but
# POSIX seems to allow it and on OS X it happens quite a lot.
return untilConcludes(os.waitpid, pid, 0)[1]
def test_setCloseOnExec(self):
"""
A file descriptor passed to L{fdesc._setCloseOnExec} is not inherited
by a new process image created with one of the exec family of
functions.
"""
with open(self.mktemp(), 'wb') as fObj:
fdesc._setCloseOnExec(fObj.fileno())
status = self._execWithFileDescriptor(fObj)
self.assertTrue(os.WIFEXITED(status))
self.assertEqual(os.WEXITSTATUS(status), 0)
def test_unsetCloseOnExec(self):
"""
A file descriptor passed to L{fdesc._unsetCloseOnExec} is inherited by
a new process image created with one of the exec family of functions.
"""
with open(self.mktemp(), 'wb') as fObj:
fdesc._setCloseOnExec(fObj.fileno())
fdesc._unsetCloseOnExec(fObj.fileno())
status = self._execWithFileDescriptor(fObj)
self.assertTrue(os.WIFEXITED(status))
self.assertEqual(os.WEXITSTATUS(status), 20)

View file

@ -0,0 +1,67 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.protocols.finger}.
"""
from twisted.trial import unittest
from twisted.protocols import finger
from twisted.test.proto_helpers import StringTransport
class FingerTestCase(unittest.TestCase):
"""
Tests for L{finger.Finger}.
"""
def setUp(self):
"""
Create and connect a L{finger.Finger} instance.
"""
self.transport = StringTransport()
self.protocol = finger.Finger()
self.protocol.makeConnection(self.transport)
def test_simple(self):
"""
When L{finger.Finger} receives a CR LF terminated line, it responds
with the default user status message - that no such user exists.
"""
self.protocol.dataReceived("moshez\r\n")
self.assertEqual(
self.transport.value(),
"Login: moshez\nNo such user\n")
def test_simpleW(self):
"""
The behavior for a query which begins with C{"/w"} is the same as the
behavior for one which does not. The user is reported as not existing.
"""
self.protocol.dataReceived("/w moshez\r\n")
self.assertEqual(
self.transport.value(),
"Login: moshez\nNo such user\n")
def test_forwarding(self):
"""
When L{finger.Finger} receives a request for a remote user, it responds
with a message rejecting the request.
"""
self.protocol.dataReceived("moshez@example.com\r\n")
self.assertEqual(
self.transport.value(),
"Finger forwarding service denied\n")
def test_list(self):
"""
When L{finger.Finger} receives a blank line, it responds with a message
rejecting the request for all online users.
"""
self.protocol.dataReceived("\r\n")
self.assertEqual(
self.transport.value(),
"Finger online list denied\n")

View file

@ -0,0 +1,98 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for formmethod module.
"""
from twisted.trial import unittest
from twisted.python import formmethod
class ArgumentTestCase(unittest.TestCase):
def argTest(self, argKlass, testPairs, badValues, *args, **kwargs):
arg = argKlass("name", *args, **kwargs)
for val, result in testPairs:
self.assertEqual(arg.coerce(val), result)
for val in badValues:
self.assertRaises(formmethod.InputError, arg.coerce, val)
def test_argument(self):
"""
Test that corce correctly raises NotImplementedError.
"""
arg = formmethod.Argument("name")
self.assertRaises(NotImplementedError, arg.coerce, "")
def testString(self):
self.argTest(formmethod.String, [("a", "a"), (1, "1"), ("", "")], ())
self.argTest(formmethod.String, [("ab", "ab"), ("abc", "abc")], ("2", ""), min=2)
self.argTest(formmethod.String, [("ab", "ab"), ("a", "a")], ("223213", "345x"), max=3)
self.argTest(formmethod.String, [("ab", "ab"), ("add", "add")], ("223213", "x"), min=2, max=3)
def testInt(self):
self.argTest(formmethod.Integer, [("3", 3), ("-2", -2), ("", None)], ("q", "2.3"))
self.argTest(formmethod.Integer, [("3", 3), ("-2", -2)], ("q", "2.3", ""), allowNone=0)
def testFloat(self):
self.argTest(formmethod.Float, [("3", 3.0), ("-2.3", -2.3), ("", None)], ("q", "2.3z"))
self.argTest(formmethod.Float, [("3", 3.0), ("-2.3", -2.3)], ("q", "2.3z", ""),
allowNone=0)
def testChoice(self):
choices = [("a", "apple", "an apple"),
("b", "banana", "ook")]
self.argTest(formmethod.Choice, [("a", "apple"), ("b", "banana")],
("c", 1), choices=choices)
def testFlags(self):
flags = [("a", "apple", "an apple"),
("b", "banana", "ook")]
self.argTest(formmethod.Flags,
[(["a"], ["apple"]), (["b", "a"], ["banana", "apple"])],
(["a", "c"], ["fdfs"]),
flags=flags)
def testBoolean(self):
tests = [("yes", 1), ("", 0), ("False", 0), ("no", 0)]
self.argTest(formmethod.Boolean, tests, ())
def test_file(self):
"""
Test the correctness of the coerce function.
"""
arg = formmethod.File("name", allowNone=0)
self.assertEqual(arg.coerce("something"), "something")
self.assertRaises(formmethod.InputError, arg.coerce, None)
arg2 = formmethod.File("name")
self.assertEqual(arg2.coerce(None), None)
def testDate(self):
goodTests = {
("2002", "12", "21"): (2002, 12, 21),
("1996", "2", "29"): (1996, 2, 29),
("", "", ""): None,
}.items()
badTests = [("2002", "2", "29"), ("xx", "2", "3"),
("2002", "13", "1"), ("1999", "12","32"),
("2002", "1"), ("2002", "2", "3", "4")]
self.argTest(formmethod.Date, goodTests, badTests)
def testRangedInteger(self):
goodTests = {"0": 0, "12": 12, "3": 3}.items()
badTests = ["-1", "x", "13", "-2000", "3.4"]
self.argTest(formmethod.IntegerRange, goodTests, badTests, 0, 12)
def testVerifiedPassword(self):
goodTests = {("foo", "foo"): "foo", ("ab", "ab"): "ab"}.items()
badTests = [("ab", "a"), ("12345", "12345"), ("", ""), ("a", "a"), ("a",), ("a", "a", "a")]
self.argTest(formmethod.VerifiedPassword, goodTests, badTests, min=2, max=4)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.tap.ftp}.
"""
from twisted.trial.unittest import TestCase
from twisted.cred import credentials, error
from twisted.tap.ftp import Options
from twisted.python import versions
from twisted.python.filepath import FilePath
class FTPOptionsTestCase(TestCase):
"""
Tests for the command line option parser used for C{twistd ftp}.
"""
usernamePassword = ('iamuser', 'thisispassword')
def setUp(self):
"""
Create a file with two users.
"""
self.filename = self.mktemp()
f = FilePath(self.filename)
f.setContent(':'.join(self.usernamePassword))
self.options = Options()
def test_passwordfileDeprecation(self):
"""
The C{--password-file} option will emit a warning stating that
said option is deprecated.
"""
self.callDeprecated(
versions.Version("Twisted", 11, 1, 0),
self.options.opt_password_file, self.filename)
def test_authAdded(self):
"""
The C{--auth} command-line option will add a checker to the list of
checkers
"""
numCheckers = len(self.options['credCheckers'])
self.options.parseOptions(['--auth', 'file:' + self.filename])
self.assertEqual(len(self.options['credCheckers']), numCheckers + 1)
def test_authFailure(self):
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that fails with L{UnauthorizedLogin} when
presented with credentials that are unknown to that checker.
"""
self.options.parseOptions(['--auth', 'file:' + self.filename])
checker = self.options['credCheckers'][-1]
invalid = credentials.UsernamePassword(self.usernamePassword[0], 'fake')
return (checker.requestAvatarId(invalid)
.addCallbacks(
lambda ignore: self.fail("Wrong password should raise error"),
lambda err: err.trap(error.UnauthorizedLogin)))
def test_authSuccess(self):
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that returns the avatar id when presented with credentials
that are known to that checker.
"""
self.options.parseOptions(['--auth', 'file:' + self.filename])
checker = self.options['credCheckers'][-1]
correct = credentials.UsernamePassword(*self.usernamePassword)
return checker.requestAvatarId(correct).addCallback(
lambda username: self.assertEqual(username, correct.username)
)

View file

@ -0,0 +1,150 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.hook module.
"""
from twisted.python import hook
from twisted.trial import unittest
class BaseClass:
"""
dummy class to help in testing.
"""
def __init__(self):
"""
dummy initializer
"""
self.calledBasePre = 0
self.calledBasePost = 0
self.calledBase = 0
def func(self, a, b):
"""
dummy method
"""
assert a == 1
assert b == 2
self.calledBase = self.calledBase + 1
class SubClass(BaseClass):
"""
another dummy class
"""
def __init__(self):
"""
another dummy initializer
"""
BaseClass.__init__(self)
self.calledSubPre = 0
self.calledSubPost = 0
self.calledSub = 0
def func(self, a, b):
"""
another dummy function
"""
assert a == 1
assert b == 2
BaseClass.func(self, a, b)
self.calledSub = self.calledSub + 1
_clean_BaseClass = BaseClass.__dict__.copy()
_clean_SubClass = SubClass.__dict__.copy()
def basePre(base, a, b):
"""
a pre-hook for the base class
"""
base.calledBasePre = base.calledBasePre + 1
def basePost(base, a, b):
"""
a post-hook for the base class
"""
base.calledBasePost = base.calledBasePost + 1
def subPre(sub, a, b):
"""
a pre-hook for the subclass
"""
sub.calledSubPre = sub.calledSubPre + 1
def subPost(sub, a, b):
"""
a post-hook for the subclass
"""
sub.calledSubPost = sub.calledSubPost + 1
class HookTestCase(unittest.TestCase):
"""
test case to make sure hooks are called
"""
def setUp(self):
"""Make sure we have clean versions of our classes."""
BaseClass.__dict__.clear()
BaseClass.__dict__.update(_clean_BaseClass)
SubClass.__dict__.clear()
SubClass.__dict__.update(_clean_SubClass)
def testBaseHook(self):
"""make sure that the base class's hook is called reliably
"""
base = BaseClass()
self.assertEqual(base.calledBase, 0)
self.assertEqual(base.calledBasePre, 0)
base.func(1,2)
self.assertEqual(base.calledBase, 1)
self.assertEqual(base.calledBasePre, 0)
hook.addPre(BaseClass, "func", basePre)
base.func(1, b=2)
self.assertEqual(base.calledBase, 2)
self.assertEqual(base.calledBasePre, 1)
hook.addPost(BaseClass, "func", basePost)
base.func(1, b=2)
self.assertEqual(base.calledBasePost, 1)
self.assertEqual(base.calledBase, 3)
self.assertEqual(base.calledBasePre, 2)
hook.removePre(BaseClass, "func", basePre)
hook.removePost(BaseClass, "func", basePost)
base.func(1, b=2)
self.assertEqual(base.calledBasePost, 1)
self.assertEqual(base.calledBase, 4)
self.assertEqual(base.calledBasePre, 2)
def testSubHook(self):
"""test interactions between base-class hooks and subclass hooks
"""
sub = SubClass()
self.assertEqual(sub.calledSub, 0)
self.assertEqual(sub.calledBase, 0)
sub.func(1, b=2)
self.assertEqual(sub.calledSub, 1)
self.assertEqual(sub.calledBase, 1)
hook.addPre(SubClass, 'func', subPre)
self.assertEqual(sub.calledSub, 1)
self.assertEqual(sub.calledBase, 1)
self.assertEqual(sub.calledSubPre, 0)
self.assertEqual(sub.calledBasePre, 0)
sub.func(1, b=2)
self.assertEqual(sub.calledSub, 2)
self.assertEqual(sub.calledBase, 2)
self.assertEqual(sub.calledSubPre, 1)
self.assertEqual(sub.calledBasePre, 0)
# let the pain begin
hook.addPre(BaseClass, 'func', basePre)
BaseClass.func(sub, 1, b=2)
# sub.func(1, b=2)
self.assertEqual(sub.calledBase, 3)
self.assertEqual(sub.calledBasePre, 1, str(sub.calledBasePre))
sub.func(1, b=2)
self.assertEqual(sub.calledBasePre, 2)
self.assertEqual(sub.calledBase, 4)
self.assertEqual(sub.calledSubPre, 2)
self.assertEqual(sub.calledSub, 3)
testCases = [HookTestCase]

View file

@ -0,0 +1,109 @@
# -*- Python -*-
__version__ = '$Revision: 1.3 $'[11:-2]
from twisted.trial import unittest
from twisted.protocols import htb
class DummyClock:
time = 0
def set(self, when):
self.time = when
def __call__(self):
return self.time
class SomeBucket(htb.Bucket):
maxburst = 100
rate = 2
class TestBucketBase(unittest.TestCase):
def setUp(self):
self._realTimeFunc = htb.time
self.clock = DummyClock()
htb.time = self.clock
def tearDown(self):
htb.time = self._realTimeFunc
class TestBucket(TestBucketBase):
def testBucketSize(self):
"""Testing the size of the bucket."""
b = SomeBucket()
fit = b.add(1000)
self.assertEqual(100, fit)
def testBucketDrain(self):
"""Testing the bucket's drain rate."""
b = SomeBucket()
fit = b.add(1000)
self.clock.set(10)
fit = b.add(1000)
self.assertEqual(20, fit)
def test_bucketEmpty(self):
"""
L{htb.Bucket.drip} returns C{True} if the bucket is empty after that drip.
"""
b = SomeBucket()
b.add(20)
self.clock.set(9)
empty = b.drip()
self.assertFalse(empty)
self.clock.set(10)
empty = b.drip()
self.assertTrue(empty)
class TestBucketNesting(TestBucketBase):
def setUp(self):
TestBucketBase.setUp(self)
self.parent = SomeBucket()
self.child1 = SomeBucket(self.parent)
self.child2 = SomeBucket(self.parent)
def testBucketParentSize(self):
# Use up most of the parent bucket.
self.child1.add(90)
fit = self.child2.add(90)
self.assertEqual(10, fit)
def testBucketParentRate(self):
# Make the parent bucket drain slower.
self.parent.rate = 1
# Fill both child1 and parent.
self.child1.add(100)
self.clock.set(10)
fit = self.child1.add(100)
# How much room was there? The child bucket would have had 20,
# but the parent bucket only ten (so no, it wouldn't make too much
# sense to have a child bucket draining faster than its parent in a real
# application.)
self.assertEqual(10, fit)
# TODO: Test the Transport stuff?
from test_pcp import DummyConsumer
class ConsumerShaperTest(TestBucketBase):
def setUp(self):
TestBucketBase.setUp(self)
self.underlying = DummyConsumer()
self.bucket = SomeBucket()
self.shaped = htb.ShapedConsumer(self.underlying, self.bucket)
def testRate(self):
# Start off with a full bucket, so the burst-size dosen't factor in
# to the calculations.
delta_t = 10
self.bucket.add(100)
self.shaped.write("x" * 100)
self.clock.set(delta_t)
self.shaped.resumeProducing()
self.assertEqual(len(self.underlying.getvalue()),
delta_t * self.bucket.rate)
def testBucketRefs(self):
self.assertEqual(self.bucket._refcount, 1)
self.shaped.stopProducing()
self.assertEqual(self.bucket._refcount, 0)

View file

@ -0,0 +1,194 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.ident module.
"""
import struct
from twisted.protocols import ident
from twisted.python import failure
from twisted.internet import error
from twisted.internet import defer
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
class ClassParserTestCase(unittest.TestCase):
"""
Test parsing of ident responses.
"""
def setUp(self):
"""
Create a ident client used in tests.
"""
self.client = ident.IdentClient()
def test_indentError(self):
"""
'UNKNOWN-ERROR' error should map to the L{ident.IdentError} exception.
"""
d = defer.Deferred()
self.client.queries.append((d, 123, 456))
self.client.lineReceived('123, 456 : ERROR : UNKNOWN-ERROR')
return self.assertFailure(d, ident.IdentError)
def test_noUSerError(self):
"""
'NO-USER' error should map to the L{ident.NoUser} exception.
"""
d = defer.Deferred()
self.client.queries.append((d, 234, 456))
self.client.lineReceived('234, 456 : ERROR : NO-USER')
return self.assertFailure(d, ident.NoUser)
def test_invalidPortError(self):
"""
'INVALID-PORT' error should map to the L{ident.InvalidPort} exception.
"""
d = defer.Deferred()
self.client.queries.append((d, 345, 567))
self.client.lineReceived('345, 567 : ERROR : INVALID-PORT')
return self.assertFailure(d, ident.InvalidPort)
def test_hiddenUserError(self):
"""
'HIDDEN-USER' error should map to the L{ident.HiddenUser} exception.
"""
d = defer.Deferred()
self.client.queries.append((d, 567, 789))
self.client.lineReceived('567, 789 : ERROR : HIDDEN-USER')
return self.assertFailure(d, ident.HiddenUser)
def test_lostConnection(self):
"""
A pending query which failed because of a ConnectionLost should
receive an L{ident.IdentError}.
"""
d = defer.Deferred()
self.client.queries.append((d, 765, 432))
self.client.connectionLost(failure.Failure(error.ConnectionLost()))
return self.assertFailure(d, ident.IdentError)
class TestIdentServer(ident.IdentServer):
def lookup(self, serverAddress, clientAddress):
return self.resultValue
class TestErrorIdentServer(ident.IdentServer):
def lookup(self, serverAddress, clientAddress):
raise self.exceptionType()
class NewException(RuntimeError):
pass
class ServerParserTestCase(unittest.TestCase):
def testErrors(self):
p = TestErrorIdentServer()
p.makeConnection(StringTransport())
L = []
p.sendLine = L.append
p.exceptionType = ident.IdentError
p.lineReceived('123, 345')
self.assertEqual(L[0], '123, 345 : ERROR : UNKNOWN-ERROR')
p.exceptionType = ident.NoUser
p.lineReceived('432, 210')
self.assertEqual(L[1], '432, 210 : ERROR : NO-USER')
p.exceptionType = ident.InvalidPort
p.lineReceived('987, 654')
self.assertEqual(L[2], '987, 654 : ERROR : INVALID-PORT')
p.exceptionType = ident.HiddenUser
p.lineReceived('756, 827')
self.assertEqual(L[3], '756, 827 : ERROR : HIDDEN-USER')
p.exceptionType = NewException
p.lineReceived('987, 789')
self.assertEqual(L[4], '987, 789 : ERROR : UNKNOWN-ERROR')
errs = self.flushLoggedErrors(NewException)
self.assertEqual(len(errs), 1)
for port in -1, 0, 65536, 65537:
del L[:]
p.lineReceived('%d, 5' % (port,))
p.lineReceived('5, %d' % (port,))
self.assertEqual(
L, ['%d, 5 : ERROR : INVALID-PORT' % (port,),
'5, %d : ERROR : INVALID-PORT' % (port,)])
def testSuccess(self):
p = TestIdentServer()
p.makeConnection(StringTransport())
L = []
p.sendLine = L.append
p.resultValue = ('SYS', 'USER')
p.lineReceived('123, 456')
self.assertEqual(L[0], '123, 456 : USERID : SYS : USER')
if struct.pack('=L', 1)[0] == '\x01':
_addr1 = '0100007F'
_addr2 = '04030201'
else:
_addr1 = '7F000001'
_addr2 = '01020304'
class ProcMixinTestCase(unittest.TestCase):
line = ('4: %s:0019 %s:02FA 0A 00000000:00000000 '
'00:00000000 00000000 0 0 10927 1 f72a5b80 '
'3000 0 0 2 -1') % (_addr1, _addr2)
def testDottedQuadFromHexString(self):
p = ident.ProcServerMixin()
self.assertEqual(p.dottedQuadFromHexString(_addr1), '127.0.0.1')
def testUnpackAddress(self):
p = ident.ProcServerMixin()
self.assertEqual(p.unpackAddress(_addr1 + ':0277'),
('127.0.0.1', 631))
def testLineParser(self):
p = ident.ProcServerMixin()
self.assertEqual(
p.parseLine(self.line),
(('127.0.0.1', 25), ('1.2.3.4', 762), 0))
def testExistingAddress(self):
username = []
p = ident.ProcServerMixin()
p.entries = lambda: iter([self.line])
p.getUsername = lambda uid: (username.append(uid), 'root')[1]
self.assertEqual(
p.lookup(('127.0.0.1', 25), ('1.2.3.4', 762)),
(p.SYSTEM_NAME, 'root'))
self.assertEqual(username, [0])
def testNonExistingAddress(self):
p = ident.ProcServerMixin()
p.entries = lambda: iter([self.line])
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 26),
('1.2.3.4', 762))
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 25),
('1.2.3.5', 762))
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 25),
('1.2.3.4', 763))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,25 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.test.iosim}.
"""
from twisted.test.iosim import FakeTransport
from twisted.trial.unittest import TestCase
class FakeTransportTests(TestCase):
"""
Tests for L{FakeTransport}
"""
def test_connectionSerial(self):
"""
Each L{FakeTransport} receives a serial number that uniquely identifies
it.
"""
a = FakeTransport(object(), True)
b = FakeTransport(object(), False)
self.assertIsInstance(a.serial, int)
self.assertIsInstance(b.serial, int)
self.assertNotEqual(a.serial, b.serial)

View file

@ -0,0 +1,351 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test running processes with the APIs in L{twisted.internet.utils}.
"""
from __future__ import division, absolute_import
import warnings, os, stat, sys, signal
from twisted.python.compat import _PY3
from twisted.python.runtime import platform
from twisted.trial import unittest
from twisted.internet import error, reactor, utils, interfaces
from twisted.internet.defer import Deferred
from twisted.python.test.test_util import SuppressedWarningsTests
class ProcessUtilsTests(unittest.TestCase):
"""
Test running a process using L{getProcessOutput}, L{getProcessValue}, and
L{getProcessOutputAndValue}.
"""
if interfaces.IReactorProcess(reactor, None) is None:
skip = "reactor doesn't implement IReactorProcess"
output = None
value = None
exe = sys.executable
def makeSourceFile(self, sourceLines):
"""
Write the given list of lines to a text file and return the absolute
path to it.
"""
script = self.mktemp()
scriptFile = file(script, 'wt')
scriptFile.write(os.linesep.join(sourceLines) + os.linesep)
scriptFile.close()
return os.path.abspath(script)
def test_output(self):
"""
L{getProcessOutput} returns a L{Deferred} which fires with the complete
output of the process it runs after that process exits.
"""
scriptFile = self.makeSourceFile([
"import sys",
"for s in 'hello world\\n':",
" sys.stdout.write(s)",
" sys.stdout.flush()"])
d = utils.getProcessOutput(self.exe, ['-u', scriptFile])
return d.addCallback(self.assertEqual, "hello world\n")
def test_outputWithErrorIgnored(self):
"""
The L{Deferred} returned by L{getProcessOutput} is fired with an
L{IOError} L{Failure} if the child process writes to stderr.
"""
# make sure stderr raises an error normally
scriptFile = self.makeSourceFile([
'import sys',
'sys.stderr.write("hello world\\n")'
])
d = utils.getProcessOutput(self.exe, ['-u', scriptFile])
d = self.assertFailure(d, IOError)
def cbFailed(err):
return self.assertFailure(err.processEnded, error.ProcessDone)
d.addCallback(cbFailed)
return d
def test_outputWithErrorCollected(self):
"""
If a C{True} value is supplied for the C{errortoo} parameter to
L{getProcessOutput}, the returned L{Deferred} fires with the child's
stderr output as well as its stdout output.
"""
scriptFile = self.makeSourceFile([
'import sys',
# Write the same value to both because ordering isn't guaranteed so
# this simplifies the test.
'sys.stdout.write("foo")',
'sys.stdout.flush()',
'sys.stderr.write("foo")',
'sys.stderr.flush()'])
d = utils.getProcessOutput(self.exe, ['-u', scriptFile], errortoo=True)
return d.addCallback(self.assertEqual, "foofoo")
def test_value(self):
"""
The L{Deferred} returned by L{getProcessValue} is fired with the exit
status of the child process.
"""
scriptFile = self.makeSourceFile(["raise SystemExit(1)"])
d = utils.getProcessValue(self.exe, ['-u', scriptFile])
return d.addCallback(self.assertEqual, 1)
def test_outputAndValue(self):
"""
The L{Deferred} returned by L{getProcessOutputAndValue} fires with a
three-tuple, the elements of which give the data written to the child's
stdout, the data written to the child's stderr, and the exit status of
the child.
"""
exe = sys.executable
scriptFile = self.makeSourceFile([
"import sys",
"sys.stdout.write('hello world!\\n')",
"sys.stderr.write('goodbye world!\\n')",
"sys.exit(1)"
])
def gotOutputAndValue(out_err_code):
out, err, code = out_err_code
self.assertEqual(out, "hello world!\n")
self.assertEqual(err, "goodbye world!" + os.linesep)
self.assertEqual(code, 1)
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
return d.addCallback(gotOutputAndValue)
def test_outputSignal(self):
"""
If the child process exits because of a signal, the L{Deferred}
returned by L{getProcessOutputAndValue} fires a L{Failure} of a tuple
containing the the child's stdout, stderr, and the signal which caused
it to exit.
"""
# Use SIGKILL here because it's guaranteed to be delivered. Using
# SIGHUP might not work in, e.g., a buildbot slave run under the
# 'nohup' command.
scriptFile = self.makeSourceFile([
"import sys, os, signal",
"sys.stdout.write('stdout bytes\\n')",
"sys.stderr.write('stderr bytes\\n')",
"sys.stdout.flush()",
"sys.stderr.flush()",
"os.kill(os.getpid(), signal.SIGKILL)"])
def gotOutputAndValue(out_err_sig):
out, err, sig = out_err_sig
self.assertEqual(out, "stdout bytes\n")
self.assertEqual(err, "stderr bytes\n")
self.assertEqual(sig, signal.SIGKILL)
d = utils.getProcessOutputAndValue(self.exe, ['-u', scriptFile])
d = self.assertFailure(d, tuple)
return d.addCallback(gotOutputAndValue)
if platform.isWindows():
test_outputSignal.skip = "Windows doesn't have real signals."
def _pathTest(self, utilFunc, check):
dir = os.path.abspath(self.mktemp())
os.makedirs(dir)
scriptFile = self.makeSourceFile([
"import os, sys",
"sys.stdout.write(os.getcwd())"])
d = utilFunc(self.exe, ['-u', scriptFile], path=dir)
d.addCallback(check, dir)
return d
def test_getProcessOutputPath(self):
"""
L{getProcessOutput} runs the given command with the working directory
given by the C{path} parameter.
"""
return self._pathTest(utils.getProcessOutput, self.assertEqual)
def test_getProcessValuePath(self):
"""
L{getProcessValue} runs the given command with the working directory
given by the C{path} parameter.
"""
def check(result, ignored):
self.assertEqual(result, 0)
return self._pathTest(utils.getProcessValue, check)
def test_getProcessOutputAndValuePath(self):
"""
L{getProcessOutputAndValue} runs the given command with the working
directory given by the C{path} parameter.
"""
def check(out_err_status, dir):
out, err, status = out_err_status
self.assertEqual(out, dir)
self.assertEqual(status, 0)
return self._pathTest(utils.getProcessOutputAndValue, check)
def _defaultPathTest(self, utilFunc, check):
# Make another directory to mess around with.
dir = os.path.abspath(self.mktemp())
os.makedirs(dir)
scriptFile = self.makeSourceFile([
"import os, sys, stat",
# Fix the permissions so we can report the working directory.
# On OS X (and maybe elsewhere), os.getcwd() fails with EACCES
# if +x is missing from the working directory.
"os.chmod(%r, stat.S_IXUSR)" % (dir,),
"sys.stdout.write(os.getcwd())"])
# Switch to it, but make sure we switch back
self.addCleanup(os.chdir, os.getcwd())
os.chdir(dir)
# Get rid of all its permissions, but make sure they get cleaned up
# later, because otherwise it might be hard to delete the trial
# temporary directory.
self.addCleanup(
os.chmod, dir, stat.S_IMODE(os.stat('.').st_mode))
os.chmod(dir, 0)
d = utilFunc(self.exe, ['-u', scriptFile])
d.addCallback(check, dir)
return d
def test_getProcessOutputDefaultPath(self):
"""
If no value is supplied for the C{path} parameter, L{getProcessOutput}
runs the given command in the same working directory as the parent
process and succeeds even if the current working directory is not
accessible.
"""
return self._defaultPathTest(utils.getProcessOutput, self.assertEqual)
def test_getProcessValueDefaultPath(self):
"""
If no value is supplied for the C{path} parameter, L{getProcessValue}
runs the given command in the same working directory as the parent
process and succeeds even if the current working directory is not
accessible.
"""
def check(result, ignored):
self.assertEqual(result, 0)
return self._defaultPathTest(utils.getProcessValue, check)
def test_getProcessOutputAndValueDefaultPath(self):
"""
If no value is supplied for the C{path} parameter,
L{getProcessOutputAndValue} runs the given command in the same working
directory as the parent process and succeeds even if the current
working directory is not accessible.
"""
def check(out_err_status, dir):
out, err, status = out_err_status
self.assertEqual(out, dir)
self.assertEqual(status, 0)
return self._defaultPathTest(
utils.getProcessOutputAndValue, check)
class SuppressWarningsTests(unittest.SynchronousTestCase):
"""
Tests for L{utils.suppressWarnings}.
"""
def test_suppressWarnings(self):
"""
L{utils.suppressWarnings} decorates a function so that the given
warnings are suppressed.
"""
result = []
def showwarning(self, *a, **kw):
result.append((a, kw))
self.patch(warnings, "showwarning", showwarning)
def f(msg):
warnings.warn(msg)
g = utils.suppressWarnings(f, (('ignore',), dict(message="This is message")))
# Start off with a sanity check - calling the original function
# should emit the warning.
f("Sanity check message")
self.assertEqual(len(result), 1)
# Now that that's out of the way, call the wrapped function, and
# make sure no new warnings show up.
g("This is message")
self.assertEqual(len(result), 1)
# Finally, emit another warning which should not be ignored, and
# make sure it is not.
g("Unignored message")
self.assertEqual(len(result), 2)
class DeferredSuppressedWarningsTests(SuppressedWarningsTests):
"""
Tests for L{utils.runWithWarningsSuppressed}, the version that supports
Deferreds.
"""
# Override the non-Deferred-supporting function from the base class with
# the function we are testing in this class:
runWithWarningsSuppressed = staticmethod(utils.runWithWarningsSuppressed)
def test_deferredCallback(self):
"""
If the function called by L{utils.runWithWarningsSuppressed} returns a
C{Deferred}, the warning filters aren't removed until the Deferred
fires.
"""
filters = [(("ignore", ".*foo.*"), {}),
(("ignore", ".*bar.*"), {})]
result = Deferred()
self.runWithWarningsSuppressed(filters, lambda: result)
warnings.warn("ignore foo")
result.callback(3)
warnings.warn("ignore foo 2")
self.assertEqual(
["ignore foo 2"], [w['message'] for w in self.flushWarnings()])
def test_deferredErrback(self):
"""
If the function called by L{utils.runWithWarningsSuppressed} returns a
C{Deferred}, the warning filters aren't removed until the Deferred
fires with an errback.
"""
filters = [(("ignore", ".*foo.*"), {}),
(("ignore", ".*bar.*"), {})]
result = Deferred()
d = self.runWithWarningsSuppressed(filters, lambda: result)
warnings.warn("ignore foo")
result.errback(ZeroDivisionError())
d.addErrback(lambda f: f.trap(ZeroDivisionError))
warnings.warn("ignore foo 2")
self.assertEqual(
["ignore foo 2"], [w['message'] for w in self.flushWarnings()])
if _PY3:
del ProcessUtilsTests

View file

@ -0,0 +1,640 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{jelly} object serialization.
"""
import datetime
import decimal
from twisted.spread import jelly, pb
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
class TestNode(object, jelly.Jellyable):
"""
An object to test jellyfying of new style class instances.
"""
classAttr = 4
def __init__(self, parent=None):
if parent:
self.id = parent.id + 1
parent.children.append(self)
else:
self.id = 1
self.parent = parent
self.children = []
class A:
"""
Dummy class.
"""
def amethod(self):
"""
Method tp be used in serialization tests.
"""
def afunc(self):
"""
A dummy function to test function serialization.
"""
class B:
"""
Dummy class.
"""
def bmethod(self):
"""
Method to be used in serialization tests.
"""
class C:
"""
Dummy class.
"""
def cmethod(self):
"""
Method to be used in serialization tests.
"""
class D(object):
"""
Dummy new-style class.
"""
class E(object):
"""
Dummy new-style class with slots.
"""
__slots__ = ("x", "y")
def __init__(self, x=None, y=None):
self.x = x
self.y = y
def __getstate__(self):
return {"x" : self.x, "y" : self.y}
def __setstate__(self, state):
self.x = state["x"]
self.y = state["y"]
class SimpleJellyTest:
def __init__(self, x, y):
self.x = x
self.y = y
def isTheSameAs(self, other):
return self.__dict__ == other.__dict__
class JellyTestCase(unittest.TestCase):
"""
Testcases for L{jelly} module serialization.
@cvar decimalData: serialized version of decimal data, to be used in tests.
@type decimalData: C{list}
"""
def _testSecurity(self, inputList, atom):
"""
Helper test method to test security options for a type.
@param inputList: a sample input for the type.
@param inputList: C{list}
@param atom: atom identifier for the type.
@type atom: C{str}
"""
c = jelly.jelly(inputList)
taster = jelly.SecurityOptions()
taster.allowBasicTypes()
# By default, it should succeed
jelly.unjelly(c, taster)
taster.allowedTypes.pop(atom)
# But it should raise an exception when disallowed
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, c, taster)
def test_methodSelfIdentity(self):
a = A()
b = B()
a.bmethod = b.bmethod
b.a = a
im_ = jelly.unjelly(jelly.jelly(b)).a.bmethod
self.assertEqual(im_.im_class, im_.im_self.__class__)
def test_methodsNotSelfIdentity(self):
"""
If a class change after an instance has been created, L{jelly.unjelly}
shoud raise a C{TypeError} when trying to unjelly the instance.
"""
a = A()
b = B()
c = C()
a.bmethod = c.cmethod
b.a = a
savecmethod = C.cmethod
del C.cmethod
try:
self.assertRaises(TypeError, jelly.unjelly, jelly.jelly(b))
finally:
C.cmethod = savecmethod
def test_newStyle(self):
n = D()
n.x = 1
n2 = D()
n.n2 = n2
n.n3 = n2
c = jelly.jelly(n)
m = jelly.unjelly(c)
self.assertIsInstance(m, D)
self.assertIdentical(m.n2, m.n3)
def test_newStyleWithSlots(self):
"""
A class defined with I{slots} can be jellied and unjellied with the
values for its attributes preserved.
"""
n = E()
n.x = 1
c = jelly.jelly(n)
m = jelly.unjelly(c)
self.assertIsInstance(m, E)
self.assertEqual(n.x, 1)
def test_typeOldStyle(self):
"""
Test that an old style class type can be jellied and unjellied
to the original type.
"""
t = [C]
r = jelly.unjelly(jelly.jelly(t))
self.assertEqual(t, r)
def test_typeNewStyle(self):
"""
Test that a new style class type can be jellied and unjellied
to the original type.
"""
t = [D]
r = jelly.unjelly(jelly.jelly(t))
self.assertEqual(t, r)
def test_typeBuiltin(self):
"""
Test that a builtin type can be jellied and unjellied to the original
type.
"""
t = [str]
r = jelly.unjelly(jelly.jelly(t))
self.assertEqual(t, r)
def test_dateTime(self):
dtn = datetime.datetime.now()
dtd = datetime.datetime.now() - dtn
input = [dtn, dtd]
c = jelly.jelly(input)
output = jelly.unjelly(c)
self.assertEqual(input, output)
self.assertNotIdentical(input, output)
def test_decimal(self):
"""
Jellying L{decimal.Decimal} instances and then unjellying the result
should produce objects which represent the values of the original
inputs.
"""
inputList = [decimal.Decimal('9.95'),
decimal.Decimal(0),
decimal.Decimal(123456),
decimal.Decimal('-78.901')]
c = jelly.jelly(inputList)
output = jelly.unjelly(c)
self.assertEqual(inputList, output)
self.assertNotIdentical(inputList, output)
decimalData = ['list', ['decimal', 995, -2], ['decimal', 0, 0],
['decimal', 123456, 0], ['decimal', -78901, -3]]
def test_decimalUnjelly(self):
"""
Unjellying the s-expressions produced by jelly for L{decimal.Decimal}
instances should result in L{decimal.Decimal} instances with the values
represented by the s-expressions.
This test also verifies that C{self.decimalData} contains valid jellied
data. This is important since L{test_decimalMissing} re-uses
C{self.decimalData} and is expected to be unable to produce
L{decimal.Decimal} instances even though the s-expression correctly
represents a list of them.
"""
expected = [decimal.Decimal('9.95'),
decimal.Decimal(0),
decimal.Decimal(123456),
decimal.Decimal('-78.901')]
output = jelly.unjelly(self.decimalData)
self.assertEqual(output, expected)
def test_decimalSecurity(self):
"""
By default, C{decimal} objects should be allowed by
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
L{jelly.InsecureJelly} when trying to unjelly it.
"""
inputList = [decimal.Decimal('9.95')]
self._testSecurity(inputList, "decimal")
def test_set(self):
"""
Jellying C{set} instances and then unjellying the result
should produce objects which represent the values of the original
inputs.
"""
inputList = [set([1, 2, 3])]
output = jelly.unjelly(jelly.jelly(inputList))
self.assertEqual(inputList, output)
self.assertNotIdentical(inputList, output)
def test_frozenset(self):
"""
Jellying C{frozenset} instances and then unjellying the result
should produce objects which represent the values of the original
inputs.
"""
inputList = [frozenset([1, 2, 3])]
output = jelly.unjelly(jelly.jelly(inputList))
self.assertEqual(inputList, output)
self.assertNotIdentical(inputList, output)
def test_setSecurity(self):
"""
By default, C{set} objects should be allowed by
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
L{jelly.InsecureJelly} when trying to unjelly it.
"""
inputList = [set([1, 2, 3])]
self._testSecurity(inputList, "set")
def test_frozensetSecurity(self):
"""
By default, C{frozenset} objects should be allowed by
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
L{jelly.InsecureJelly} when trying to unjelly it.
"""
inputList = [frozenset([1, 2, 3])]
self._testSecurity(inputList, "frozenset")
def test_oldSets(self):
"""
Test jellying C{sets.Set}: it should serialize to the same thing as
C{set} jelly, and be unjellied as C{set} if available.
"""
inputList = [jelly._sets.Set([1, 2, 3])]
inputJelly = jelly.jelly(inputList)
self.assertEqual(inputJelly, jelly.jelly([set([1, 2, 3])]))
output = jelly.unjelly(inputJelly)
# Even if the class is different, it should coerce to the same list
self.assertEqual(list(inputList[0]), list(output[0]))
if set is jelly._sets.Set:
self.assertIsInstance(output[0], jelly._sets.Set)
else:
self.assertIsInstance(output[0], set)
def test_oldImmutableSets(self):
"""
Test jellying C{sets.ImmutableSet}: it should serialize to the same
thing as C{frozenset} jelly, and be unjellied as C{frozenset} if
available.
"""
inputList = [jelly._sets.ImmutableSet([1, 2, 3])]
inputJelly = jelly.jelly(inputList)
self.assertEqual(inputJelly, jelly.jelly([frozenset([1, 2, 3])]))
output = jelly.unjelly(inputJelly)
# Even if the class is different, it should coerce to the same list
self.assertEqual(list(inputList[0]), list(output[0]))
if frozenset is jelly._sets.ImmutableSet:
self.assertIsInstance(output[0], jelly._sets.ImmutableSet)
else:
self.assertIsInstance(output[0], frozenset)
def test_simple(self):
"""
Simplest test case.
"""
self.failUnless(SimpleJellyTest('a', 'b').isTheSameAs(
SimpleJellyTest('a', 'b')))
a = SimpleJellyTest(1, 2)
cereal = jelly.jelly(a)
b = jelly.unjelly(cereal)
self.failUnless(a.isTheSameAs(b))
def test_identity(self):
"""
Test to make sure that objects retain identity properly.
"""
x = []
y = (x)
x.append(y)
x.append(y)
self.assertIdentical(x[0], x[1])
self.assertIdentical(x[0][0], x)
s = jelly.jelly(x)
z = jelly.unjelly(s)
self.assertIdentical(z[0], z[1])
self.assertIdentical(z[0][0], z)
def test_unicode(self):
x = unicode('blah')
y = jelly.unjelly(jelly.jelly(x))
self.assertEqual(x, y)
self.assertEqual(type(x), type(y))
def test_stressReferences(self):
reref = []
toplevelTuple = ({'list': reref}, reref)
reref.append(toplevelTuple)
s = jelly.jelly(toplevelTuple)
z = jelly.unjelly(s)
self.assertIdentical(z[0]['list'], z[1])
self.assertIdentical(z[0]['list'][0], z)
def test_moreReferences(self):
a = []
t = (a,)
a.append((t,))
s = jelly.jelly(t)
z = jelly.unjelly(s)
self.assertIdentical(z[0][0][0], z)
def test_typeSecurity(self):
"""
Test for type-level security of serialization.
"""
taster = jelly.SecurityOptions()
dct = jelly.jelly({})
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, dct, taster)
def test_newStyleClasses(self):
j = jelly.jelly(D)
uj = jelly.unjelly(D)
self.assertIdentical(D, uj)
def test_lotsaTypes(self):
"""
Test for all types currently supported in jelly
"""
a = A()
jelly.unjelly(jelly.jelly(a))
jelly.unjelly(jelly.jelly(a.amethod))
items = [afunc, [1, 2, 3], not bool(1), bool(1), 'test', 20.3,
(1, 2, 3), None, A, unittest, {'a': 1}, A.amethod]
for i in items:
self.assertEqual(i, jelly.unjelly(jelly.jelly(i)))
def test_setState(self):
global TupleState
class TupleState:
def __init__(self, other):
self.other = other
def __getstate__(self):
return (self.other,)
def __setstate__(self, state):
self.other = state[0]
def __hash__(self):
return hash(self.other)
a = A()
t1 = TupleState(a)
t2 = TupleState(a)
t3 = TupleState((t1, t2))
d = {t1: t1, t2: t2, t3: t3, "t3": t3}
t3prime = jelly.unjelly(jelly.jelly(d))["t3"]
self.assertIdentical(t3prime.other[0].other, t3prime.other[1].other)
def test_classSecurity(self):
"""
Test for class-level security of serialization.
"""
taster = jelly.SecurityOptions()
taster.allowInstancesOf(A, B)
a = A()
b = B()
c = C()
# add a little complexity to the data
a.b = b
a.c = c
# and a backreference
a.x = b
b.c = c
# first, a friendly insecure serialization
friendly = jelly.jelly(a, taster)
x = jelly.unjelly(friendly, taster)
self.assertIsInstance(x.c, jelly.Unpersistable)
# now, a malicious one
mean = jelly.jelly(a)
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, mean, taster)
self.assertIdentical(x.x, x.b, "Identity mismatch")
# test class serialization
friendly = jelly.jelly(A, taster)
x = jelly.unjelly(friendly, taster)
self.assertIdentical(x, A, "A came back: %s" % x)
def test_unjellyable(self):
"""
Test that if Unjellyable is used to deserialize a jellied object,
state comes out right.
"""
class JellyableTestClass(jelly.Jellyable):
pass
jelly.setUnjellyableForClass(JellyableTestClass, jelly.Unjellyable)
input = JellyableTestClass()
input.attribute = 'value'
output = jelly.unjelly(jelly.jelly(input))
self.assertEqual(output.attribute, 'value')
self.assertIsInstance(output, jelly.Unjellyable)
def test_persistentStorage(self):
perst = [{}, 1]
def persistentStore(obj, jel, perst = perst):
perst[1] = perst[1] + 1
perst[0][perst[1]] = obj
return str(perst[1])
def persistentLoad(pidstr, unj, perst = perst):
pid = int(pidstr)
return perst[0][pid]
a = SimpleJellyTest(1, 2)
b = SimpleJellyTest(3, 4)
c = SimpleJellyTest(5, 6)
a.b = b
a.c = c
c.b = b
jel = jelly.jelly(a, persistentStore = persistentStore)
x = jelly.unjelly(jel, persistentLoad = persistentLoad)
self.assertIdentical(x.b, x.c.b)
self.failUnless(perst[0], "persistentStore was not called.")
self.assertIdentical(x.b, a.b, "Persistent storage identity failure.")
def test_newStyleClassesAttributes(self):
n = TestNode()
n1 = TestNode(n)
n11 = TestNode(n1)
n2 = TestNode(n)
# Jelly it
jel = jelly.jelly(n)
m = jelly.unjelly(jel)
# Check that it has been restored ok
self._check_newstyle(n, m)
def _check_newstyle(self, a, b):
self.assertEqual(a.id, b.id)
self.assertEqual(a.classAttr, 4)
self.assertEqual(b.classAttr, 4)
self.assertEqual(len(a.children), len(b.children))
for x, y in zip(a.children, b.children):
self._check_newstyle(x, y)
def test_referenceable(self):
"""
A L{pb.Referenceable} instance jellies to a structure which unjellies to
a L{pb.RemoteReference}. The C{RemoteReference} has a I{luid} that
matches up with the local object key in the L{pb.Broker} which sent the
L{Referenceable}.
"""
ref = pb.Referenceable()
jellyBroker = pb.Broker()
jellyBroker.makeConnection(StringTransport())
j = jelly.jelly(ref, invoker=jellyBroker)
unjellyBroker = pb.Broker()
unjellyBroker.makeConnection(StringTransport())
uj = jelly.unjelly(j, invoker=unjellyBroker)
self.assertIn(uj.luid, jellyBroker.localObjects)
class ClassA(pb.Copyable, pb.RemoteCopy):
def __init__(self):
self.ref = ClassB(self)
class ClassB(pb.Copyable, pb.RemoteCopy):
def __init__(self, ref):
self.ref = ref
class CircularReferenceTestCase(unittest.TestCase):
"""
Tests for circular references handling in the jelly/unjelly process.
"""
def test_simpleCircle(self):
jelly.setUnjellyableForClass(ClassA, ClassA)
jelly.setUnjellyableForClass(ClassB, ClassB)
a = jelly.unjelly(jelly.jelly(ClassA()))
self.assertIdentical(a.ref.ref, a,
"Identity not preserved in circular reference")
def test_circleWithInvoker(self):
class DummyInvokerClass:
pass
dummyInvoker = DummyInvokerClass()
dummyInvoker.serializingPerspective = None
a0 = ClassA()
jelly.setUnjellyableForClass(ClassA, ClassA)
jelly.setUnjellyableForClass(ClassB, ClassB)
j = jelly.jelly(a0, invoker=dummyInvoker)
a1 = jelly.unjelly(j)
self.failUnlessIdentical(a1.ref.ref, a1,
"Identity not preserved in circular reference")
def test_set(self):
"""
Check that a C{set} can contain a circular reference and be serialized
and unserialized without losing the reference.
"""
s = set()
a = SimpleJellyTest(s, None)
s.add(a)
res = jelly.unjelly(jelly.jelly(a))
self.assertIsInstance(res.x, set)
self.assertEqual(list(res.x), [res])
def test_frozenset(self):
"""
Check that a C{frozenset} can contain a circular reference and be
serializeserialized without losing the reference.
"""
a = SimpleJellyTest(None, None)
s = frozenset([a])
a.x = s
res = jelly.unjelly(jelly.jelly(a))
self.assertIsInstance(res.x, frozenset)
self.assertEqual(list(res.x), [res])

View file

@ -0,0 +1,445 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.lockfile}.
"""
import os, errno
from twisted.trial import unittest
from twisted.python import lockfile
from twisted.python.runtime import platform
skipKill = None
if platform.isWindows():
try:
from win32api import OpenProcess
import pywintypes
except ImportError:
skipKill = ("On windows, lockfile.kill is not implemented in the "
"absence of win32api and/or pywintypes.")
class UtilTests(unittest.TestCase):
"""
Tests for the helper functions used to implement L{FilesystemLock}.
"""
def test_symlinkEEXIST(self):
"""
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EEXIST}
when an attempt is made to create a symlink which already exists.
"""
name = self.mktemp()
lockfile.symlink('foo', name)
exc = self.assertRaises(OSError, lockfile.symlink, 'foo', name)
self.assertEqual(exc.errno, errno.EEXIST)
def test_symlinkEIOWindows(self):
"""
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EIO} when
the underlying L{rename} call fails with L{EIO}.
Renaming a file on Windows may fail if the target of the rename is in
the process of being deleted (directory deletion appears not to be
atomic).
"""
name = self.mktemp()
def fakeRename(src, dst):
raise IOError(errno.EIO, None)
self.patch(lockfile, 'rename', fakeRename)
exc = self.assertRaises(IOError, lockfile.symlink, name, "foo")
self.assertEqual(exc.errno, errno.EIO)
if not platform.isWindows():
test_symlinkEIOWindows.skip = (
"special rename EIO handling only necessary and correct on "
"Windows.")
def test_readlinkENOENT(self):
"""
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{ENOENT}
when an attempt is made to read a symlink which does not exist.
"""
name = self.mktemp()
exc = self.assertRaises(OSError, lockfile.readlink, name)
self.assertEqual(exc.errno, errno.ENOENT)
def test_readlinkEACCESWindows(self):
"""
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{EACCES}
on Windows when the underlying file open attempt fails with C{EACCES}.
Opening a file on Windows may fail if the path is inside a directory
which is in the process of being deleted (directory deletion appears
not to be atomic).
"""
name = self.mktemp()
def fakeOpen(path, mode):
raise IOError(errno.EACCES, None)
self.patch(lockfile, '_open', fakeOpen)
exc = self.assertRaises(IOError, lockfile.readlink, name)
self.assertEqual(exc.errno, errno.EACCES)
if not platform.isWindows():
test_readlinkEACCESWindows.skip = (
"special readlink EACCES handling only necessary and correct on "
"Windows.")
def test_kill(self):
"""
L{lockfile.kill} returns without error if passed the PID of a
process which exists and signal C{0}.
"""
lockfile.kill(os.getpid(), 0)
test_kill.skip = skipKill
def test_killESRCH(self):
"""
L{lockfile.kill} raises L{OSError} with errno of L{ESRCH} if
passed a PID which does not correspond to any process.
"""
# Hopefully there is no process with PID 2 ** 31 - 1
exc = self.assertRaises(OSError, lockfile.kill, 2 ** 31 - 1, 0)
self.assertEqual(exc.errno, errno.ESRCH)
test_killESRCH.skip = skipKill
def test_noKillCall(self):
"""
Verify that when L{lockfile.kill} does end up as None (e.g. on Windows
without pywin32), it doesn't end up being called and raising a
L{TypeError}.
"""
self.patch(lockfile, "kill", None)
fl = lockfile.FilesystemLock(self.mktemp())
fl.lock()
self.assertFalse(fl.lock())
class LockingTestCase(unittest.TestCase):
def _symlinkErrorTest(self, errno):
def fakeSymlink(source, dest):
raise OSError(errno, None)
self.patch(lockfile, 'symlink', fakeSymlink)
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
exc = self.assertRaises(OSError, lock.lock)
self.assertEqual(exc.errno, errno)
def test_symlinkError(self):
"""
An exception raised by C{symlink} other than C{EEXIST} is passed up to
the caller of L{FilesystemLock.lock}.
"""
self._symlinkErrorTest(errno.ENOSYS)
def test_symlinkErrorPOSIX(self):
"""
An L{OSError} raised by C{symlink} on a POSIX platform with an errno of
C{EACCES} or C{EIO} is passed to the caller of L{FilesystemLock.lock}.
On POSIX, unlike on Windows, these are unexpected errors which cannot
be handled by L{FilesystemLock}.
"""
self._symlinkErrorTest(errno.EACCES)
self._symlinkErrorTest(errno.EIO)
if platform.isWindows():
test_symlinkErrorPOSIX.skip = (
"POSIX-specific error propagation not expected on Windows.")
def test_cleanlyAcquire(self):
"""
If the lock has never been held, it can be acquired and the C{clean}
and C{locked} attributes are set to C{True}.
"""
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
self.assertTrue(lock.lock())
self.assertTrue(lock.clean)
self.assertTrue(lock.locked)
def test_cleanlyRelease(self):
"""
If a lock is released cleanly, it can be re-acquired and the C{clean}
and C{locked} attributes are set to C{True}.
"""
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
self.assertTrue(lock.lock())
lock.unlock()
self.assertFalse(lock.locked)
lock = lockfile.FilesystemLock(lockf)
self.assertTrue(lock.lock())
self.assertTrue(lock.clean)
self.assertTrue(lock.locked)
def test_cannotLockLocked(self):
"""
If a lock is currently locked, it cannot be locked again.
"""
lockf = self.mktemp()
firstLock = lockfile.FilesystemLock(lockf)
self.assertTrue(firstLock.lock())
secondLock = lockfile.FilesystemLock(lockf)
self.assertFalse(secondLock.lock())
self.assertFalse(secondLock.locked)
def test_uncleanlyAcquire(self):
"""
If a lock was held by a process which no longer exists, it can be
acquired, the C{clean} attribute is set to C{False}, and the
C{locked} attribute is set to C{True}.
"""
owner = 12345
def fakeKill(pid, signal):
if signal != 0:
raise OSError(errno.EPERM, None)
if pid == owner:
raise OSError(errno.ESRCH, None)
lockf = self.mktemp()
self.patch(lockfile, 'kill', fakeKill)
lockfile.symlink(str(owner), lockf)
lock = lockfile.FilesystemLock(lockf)
self.assertTrue(lock.lock())
self.assertFalse(lock.clean)
self.assertTrue(lock.locked)
self.assertEqual(lockfile.readlink(lockf), str(os.getpid()))
def test_lockReleasedBeforeCheck(self):
"""
If the lock is initially held but then released before it can be
examined to determine if the process which held it still exists, it is
acquired and the C{clean} and C{locked} attributes are set to C{True}.
"""
def fakeReadlink(name):
# Pretend to be another process releasing the lock.
lockfile.rmlink(lockf)
# Fall back to the real implementation of readlink.
readlinkPatch.restore()
return lockfile.readlink(name)
readlinkPatch = self.patch(lockfile, 'readlink', fakeReadlink)
def fakeKill(pid, signal):
if signal != 0:
raise OSError(errno.EPERM, None)
if pid == 43125:
raise OSError(errno.ESRCH, None)
self.patch(lockfile, 'kill', fakeKill)
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
lockfile.symlink(str(43125), lockf)
self.assertTrue(lock.lock())
self.assertTrue(lock.clean)
self.assertTrue(lock.locked)
def test_lockReleasedDuringAcquireSymlink(self):
"""
If the lock is released while an attempt is made to acquire
it, the lock attempt fails and C{FilesystemLock.lock} returns
C{False}. This can happen on Windows when L{lockfile.symlink}
fails with L{IOError} of C{EIO} because another process is in
the middle of a call to L{os.rmdir} (implemented in terms of
RemoveDirectory) which is not atomic.
"""
def fakeSymlink(src, dst):
# While another process id doing os.rmdir which the Windows
# implementation of rmlink does, a rename call will fail with EIO.
raise OSError(errno.EIO, None)
self.patch(lockfile, 'symlink', fakeSymlink)
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
self.assertFalse(lock.lock())
self.assertFalse(lock.locked)
if not platform.isWindows():
test_lockReleasedDuringAcquireSymlink.skip = (
"special rename EIO handling only necessary and correct on "
"Windows.")
def test_lockReleasedDuringAcquireReadlink(self):
"""
If the lock is initially held but is released while an attempt
is made to acquire it, the lock attempt fails and
L{FilesystemLock.lock} returns C{False}.
"""
def fakeReadlink(name):
# While another process is doing os.rmdir which the
# Windows implementation of rmlink does, a readlink call
# will fail with EACCES.
raise IOError(errno.EACCES, None)
readlinkPatch = self.patch(lockfile, 'readlink', fakeReadlink)
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
lockfile.symlink(str(43125), lockf)
self.assertFalse(lock.lock())
self.assertFalse(lock.locked)
if not platform.isWindows():
test_lockReleasedDuringAcquireReadlink.skip = (
"special readlink EACCES handling only necessary and correct on "
"Windows.")
def _readlinkErrorTest(self, exceptionType, errno):
def fakeReadlink(name):
raise exceptionType(errno, None)
self.patch(lockfile, 'readlink', fakeReadlink)
lockf = self.mktemp()
# Make it appear locked so it has to use readlink
lockfile.symlink(str(43125), lockf)
lock = lockfile.FilesystemLock(lockf)
exc = self.assertRaises(exceptionType, lock.lock)
self.assertEqual(exc.errno, errno)
self.assertFalse(lock.locked)
def test_readlinkError(self):
"""
An exception raised by C{readlink} other than C{ENOENT} is passed up to
the caller of L{FilesystemLock.lock}.
"""
self._readlinkErrorTest(OSError, errno.ENOSYS)
self._readlinkErrorTest(IOError, errno.ENOSYS)
def test_readlinkErrorPOSIX(self):
"""
Any L{IOError} raised by C{readlink} on a POSIX platform passed to the
caller of L{FilesystemLock.lock}.
On POSIX, unlike on Windows, these are unexpected errors which cannot
be handled by L{FilesystemLock}.
"""
self._readlinkErrorTest(IOError, errno.ENOSYS)
self._readlinkErrorTest(IOError, errno.EACCES)
if platform.isWindows():
test_readlinkErrorPOSIX.skip = (
"POSIX-specific error propagation not expected on Windows.")
def test_lockCleanedUpConcurrently(self):
"""
If a second process cleans up the lock after a first one checks the
lock and finds that no process is holding it, the first process does
not fail when it tries to clean up the lock.
"""
def fakeRmlink(name):
rmlinkPatch.restore()
# Pretend to be another process cleaning up the lock.
lockfile.rmlink(lockf)
# Fall back to the real implementation of rmlink.
return lockfile.rmlink(name)
rmlinkPatch = self.patch(lockfile, 'rmlink', fakeRmlink)
def fakeKill(pid, signal):
if signal != 0:
raise OSError(errno.EPERM, None)
if pid == 43125:
raise OSError(errno.ESRCH, None)
self.patch(lockfile, 'kill', fakeKill)
lockf = self.mktemp()
lock = lockfile.FilesystemLock(lockf)
lockfile.symlink(str(43125), lockf)
self.assertTrue(lock.lock())
self.assertTrue(lock.clean)
self.assertTrue(lock.locked)
def test_rmlinkError(self):
"""
An exception raised by L{rmlink} other than C{ENOENT} is passed up
to the caller of L{FilesystemLock.lock}.
"""
def fakeRmlink(name):
raise OSError(errno.ENOSYS, None)
self.patch(lockfile, 'rmlink', fakeRmlink)
def fakeKill(pid, signal):
if signal != 0:
raise OSError(errno.EPERM, None)
if pid == 43125:
raise OSError(errno.ESRCH, None)
self.patch(lockfile, 'kill', fakeKill)
lockf = self.mktemp()
# Make it appear locked so it has to use readlink
lockfile.symlink(str(43125), lockf)
lock = lockfile.FilesystemLock(lockf)
exc = self.assertRaises(OSError, lock.lock)
self.assertEqual(exc.errno, errno.ENOSYS)
self.assertFalse(lock.locked)
def test_killError(self):
"""
If L{kill} raises an exception other than L{OSError} with errno set to
C{ESRCH}, the exception is passed up to the caller of
L{FilesystemLock.lock}.
"""
def fakeKill(pid, signal):
raise OSError(errno.EPERM, None)
self.patch(lockfile, 'kill', fakeKill)
lockf = self.mktemp()
# Make it appear locked so it has to use readlink
lockfile.symlink(str(43125), lockf)
lock = lockfile.FilesystemLock(lockf)
exc = self.assertRaises(OSError, lock.lock)
self.assertEqual(exc.errno, errno.EPERM)
self.assertFalse(lock.locked)
def test_unlockOther(self):
"""
L{FilesystemLock.unlock} raises L{ValueError} if called for a lock
which is held by a different process.
"""
lockf = self.mktemp()
lockfile.symlink(str(os.getpid() + 1), lockf)
lock = lockfile.FilesystemLock(lockf)
self.assertRaises(ValueError, lock.unlock)
def test_isLocked(self):
"""
L{isLocked} returns C{True} if the named lock is currently locked,
C{False} otherwise.
"""
lockf = self.mktemp()
self.assertFalse(lockfile.isLocked(lockf))
lock = lockfile.FilesystemLock(lockf)
self.assertTrue(lock.lock())
self.assertTrue(lockfile.isLocked(lockf))
lock.unlock()
self.assertFalse(lockfile.isLocked(lockf))

View file

@ -0,0 +1,882 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.log}.
"""
from __future__ import division, absolute_import, print_function
from twisted.python.compat import _PY3, NativeStringIO as StringIO
import os, sys, time, logging, warnings, calendar
from twisted.trial import unittest
from twisted.python import log, failure
class FakeWarning(Warning):
"""
A unique L{Warning} subclass used by tests for interactions of
L{twisted.python.log} with the L{warnings} module.
"""
class LogTest(unittest.SynchronousTestCase):
def setUp(self):
self.catcher = []
self.observer = self.catcher.append
log.addObserver(self.observer)
self.addCleanup(log.removeObserver, self.observer)
def testObservation(self):
catcher = self.catcher
log.msg("test", testShouldCatch=True)
i = catcher.pop()
self.assertEqual(i["message"][0], "test")
self.assertEqual(i["testShouldCatch"], True)
self.assertIn("time", i)
self.assertEqual(len(catcher), 0)
def testContext(self):
catcher = self.catcher
log.callWithContext({"subsystem": "not the default",
"subsubsystem": "a",
"other": "c"},
log.callWithContext,
{"subsubsystem": "b"}, log.msg, "foo", other="d")
i = catcher.pop()
self.assertEqual(i['subsubsystem'], 'b')
self.assertEqual(i['subsystem'], 'not the default')
self.assertEqual(i['other'], 'd')
self.assertEqual(i['message'][0], 'foo')
def testErrors(self):
for e, ig in [("hello world","hello world"),
(KeyError(), KeyError),
(failure.Failure(RuntimeError()), RuntimeError)]:
log.err(e)
i = self.catcher.pop()
self.assertEqual(i['isError'], 1)
self.flushLoggedErrors(ig)
def testErrorsWithWhy(self):
for e, ig in [("hello world","hello world"),
(KeyError(), KeyError),
(failure.Failure(RuntimeError()), RuntimeError)]:
log.err(e, 'foobar')
i = self.catcher.pop()
self.assertEqual(i['isError'], 1)
self.assertEqual(i['why'], 'foobar')
self.flushLoggedErrors(ig)
def test_erroneousErrors(self):
"""
Exceptions raised by log observers are logged but the observer which
raised the exception remains registered with the publisher. These
exceptions do not prevent the event from being sent to other observers
registered with the publisher.
"""
L1 = []
L2 = []
def broken(events):
1 // 0
for observer in [L1.append, broken, L2.append]:
log.addObserver(observer)
self.addCleanup(log.removeObserver, observer)
for i in range(3):
# Reset the lists for simpler comparison.
L1[:] = []
L2[:] = []
# Send out the event which will break one of the observers.
log.msg("Howdy, y'all.")
# The broken observer should have caused this to be logged.
excs = self.flushLoggedErrors(ZeroDivisionError)
del self.catcher[:]
self.assertEqual(len(excs), 1)
# Both other observers should have seen the message.
self.assertEqual(len(L1), 2)
self.assertEqual(len(L2), 2)
# The order is slightly wrong here. The first event should be
# delivered to all observers; then, errors should be delivered.
self.assertEqual(L1[1]['message'], ("Howdy, y'all.",))
self.assertEqual(L2[0]['message'], ("Howdy, y'all.",))
def test_doubleErrorDoesNotRemoveObserver(self):
"""
If logging causes an error, make sure that if logging the fact that
logging failed also causes an error, the log observer is not removed.
"""
events = []
errors = []
publisher = log.LogPublisher()
class FailingObserver(object):
calls = 0
def log(self, msg, **kwargs):
# First call raises RuntimeError:
self.calls += 1
if self.calls < 2:
raise RuntimeError("Failure #%s" % (self.calls,))
else:
events.append(msg)
observer = FailingObserver()
publisher.addObserver(observer.log)
self.assertEqual(publisher.observers, [observer.log])
try:
# When observer throws, the publisher attempts to log the fact by
# calling self._err()... which also fails with recursion error:
oldError = publisher._err
def failingErr(failure, why, **kwargs):
errors.append(failure.value)
raise RuntimeError("Fake recursion error")
publisher._err = failingErr
publisher.msg("error in first observer")
finally:
publisher._err = oldError
# Observer should still exist; we do this in finally since before
# bug was fixed the test would fail due to uncaught exception, so
# we want failing assert too in that case:
self.assertEqual(publisher.observers, [observer.log])
# The next message should succeed:
publisher.msg("but this should succeed")
self.assertEqual(observer.calls, 2)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['message'], ("but this should succeed",))
self.assertEqual(len(errors), 1)
self.assertIsInstance(errors[0], RuntimeError)
def test_showwarning(self):
"""
L{twisted.python.log.showwarning} emits the warning as a message
to the Twisted logging system.
"""
publisher = log.LogPublisher()
publisher.addObserver(self.observer)
publisher.showwarning(
FakeWarning("unique warning message"), FakeWarning,
"warning-filename.py", 27)
event = self.catcher.pop()
self.assertEqual(
event['format'] % event,
'warning-filename.py:27: twisted.test.test_log.FakeWarning: '
'unique warning message')
self.assertEqual(self.catcher, [])
# Python 2.6 requires that any function used to override the
# warnings.showwarning API accept a "line" parameter or a
# deprecation warning is emitted.
publisher.showwarning(
FakeWarning("unique warning message"), FakeWarning,
"warning-filename.py", 27, line=object())
event = self.catcher.pop()
self.assertEqual(
event['format'] % event,
'warning-filename.py:27: twisted.test.test_log.FakeWarning: '
'unique warning message')
self.assertEqual(self.catcher, [])
def test_warningToFile(self):
"""
L{twisted.python.log.showwarning} passes warnings with an explicit file
target on to the underlying Python warning system.
"""
# log.showwarning depends on _oldshowwarning being set, which only
# happens in startLogging(), which doesn't happen if you're not
# running under trial. So this test only passes by accident of runner
# environment.
if log._oldshowwarning is None:
raise unittest.SkipTest("Currently this test only runs under trial.")
message = "another unique message"
category = FakeWarning
filename = "warning-filename.py"
lineno = 31
output = StringIO()
log.showwarning(message, category, filename, lineno, file=output)
self.assertEqual(
output.getvalue(),
warnings.formatwarning(message, category, filename, lineno))
# In Python 2.6, warnings.showwarning accepts a "line" argument which
# gives the source line the warning message is to include.
if sys.version_info >= (2, 6):
line = "hello world"
output = StringIO()
log.showwarning(message, category, filename, lineno, file=output,
line=line)
self.assertEqual(
output.getvalue(),
warnings.formatwarning(message, category, filename, lineno,
line))
def test_publisherReportsBrokenObserversPrivately(self):
"""
Log publisher does not use the global L{log.err} when reporting broken
observers.
"""
errors = []
def logError(eventDict):
if eventDict.get("isError"):
errors.append(eventDict["failure"].value)
def fail(eventDict):
raise RuntimeError("test_publisherLocalyReportsBrokenObservers")
publisher = log.LogPublisher()
publisher.addObserver(logError)
publisher.addObserver(fail)
publisher.msg("Hello!")
self.assertEqual(publisher.observers, [logError, fail])
self.assertEqual(len(errors), 1)
self.assertIsInstance(errors[0], RuntimeError)
class FakeFile(list):
def write(self, bytes):
self.append(bytes)
def flush(self):
pass
class EvilStr:
def __str__(self):
1//0
class EvilRepr:
def __str__(self):
return "Happy Evil Repr"
def __repr__(self):
1//0
class EvilReprStr(EvilStr, EvilRepr):
pass
class LogPublisherTestCaseMixin:
def setUp(self):
"""
Add a log observer which records log events in C{self.out}. Also,
make sure the default string encoding is ASCII so that
L{testSingleUnicode} can test the behavior of logging unencodable
unicode messages.
"""
self.out = FakeFile()
self.lp = log.LogPublisher()
self.flo = log.FileLogObserver(self.out)
self.lp.addObserver(self.flo.emit)
try:
str(u'\N{VULGAR FRACTION ONE HALF}')
except UnicodeEncodeError:
# This is the behavior we want - don't change anything.
self._origEncoding = None
else:
if _PY3:
self._origEncoding = None
return
reload(sys)
self._origEncoding = sys.getdefaultencoding()
sys.setdefaultencoding('ascii')
def tearDown(self):
"""
Verify that everything written to the fake file C{self.out} was a
C{str}. Also, restore the default string encoding to its previous
setting, if it was modified by L{setUp}.
"""
for chunk in self.out:
self.failUnless(isinstance(chunk, str), "%r was not a string" % (chunk,))
if self._origEncoding is not None:
sys.setdefaultencoding(self._origEncoding)
del sys.setdefaultencoding
class LogPublisherTestCase(LogPublisherTestCaseMixin, unittest.SynchronousTestCase):
def testSingleString(self):
self.lp.msg("Hello, world.")
self.assertEqual(len(self.out), 1)
def testMultipleString(self):
# Test some stupid behavior that will be deprecated real soon.
# If you are reading this and trying to learn how the logging
# system works, *do not use this feature*.
self.lp.msg("Hello, ", "world.")
self.assertEqual(len(self.out), 1)
def test_singleUnicode(self):
"""
L{log.LogPublisher.msg} does not accept non-ASCII Unicode on Python 2,
logging an error instead.
On Python 3, where Unicode is default message type, the message is
logged normally.
"""
message = u"Hello, \N{VULGAR FRACTION ONE HALF} world."
self.lp.msg(message)
self.assertEqual(len(self.out), 1)
if _PY3:
self.assertIn(message, self.out[0])
else:
self.assertIn('with str error', self.out[0])
self.assertIn('UnicodeEncodeError', self.out[0])
class FileObserverTestCase(LogPublisherTestCaseMixin, unittest.SynchronousTestCase):
"""
Tests for L{log.FileObserver}.
"""
def _getTimezoneOffsetTest(self, tzname, daylightOffset, standardOffset):
"""
Verify that L{getTimezoneOffset} produces the expected offset for a
certain timezone both when daylight saving time is in effect and when
it is not.
@param tzname: The name of a timezone to exercise.
@type tzname: L{bytes}
@param daylightOffset: The number of seconds west of UTC the timezone
should be when daylight saving time is in effect.
@type daylightOffset: L{int}
@param standardOffset: The number of seconds west of UTC the timezone
should be when daylight saving time is not in effect.
@type standardOffset: L{int}
"""
if getattr(time, 'tzset', None) is None:
raise unittest.SkipTest(
"Platform cannot change timezone, cannot verify correct offsets "
"in well-known timezones.")
originalTimezone = os.environ.get('TZ', None)
try:
os.environ['TZ'] = tzname
time.tzset()
# The behavior of mktime depends on the current timezone setting.
# So only do this after changing the timezone.
# Compute a POSIX timestamp for a certain date and time that is
# known to occur at a time when daylight saving time is in effect.
localDaylightTuple = (2006, 6, 30, 0, 0, 0, 4, 181, 1)
daylight = time.mktime(localDaylightTuple)
# Compute a POSIX timestamp for a certain date and time that is
# known to occur at a time when daylight saving time is not in
# effect.
localStandardTuple = (2007, 1, 31, 0, 0, 0, 2, 31, 0)
standard = time.mktime(localStandardTuple)
self.assertEqual(
(self.flo.getTimezoneOffset(daylight),
self.flo.getTimezoneOffset(standard)),
(daylightOffset, standardOffset))
finally:
if originalTimezone is None:
del os.environ['TZ']
else:
os.environ['TZ'] = originalTimezone
time.tzset()
def test_getTimezoneOffsetWestOfUTC(self):
"""
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
correct values for the current C{TZ} environment setting for at least
some cases. This test method exercises a timezone that is west of UTC
(and should produce positive results).
"""
self._getTimezoneOffsetTest("America/New_York", 14400, 18000)
def test_getTimezoneOffsetEastOfUTC(self):
"""
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
correct values for the current C{TZ} environment setting for at least
some cases. This test method exercises a timezone that is east of UTC
(and should produce negative results).
"""
self._getTimezoneOffsetTest("Europe/Berlin", -7200, -3600)
def test_getTimezoneOffsetWithoutDaylightSavingTime(self):
"""
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
correct values for the current C{TZ} environment setting for at least
some cases. This test method exercises a timezone that does not use
daylight saving time at all (so both summer and winter time test values
should have the same offset).
"""
# Test a timezone that doesn't have DST. mktime() implementations
# available for testing seem happy to produce results for this even
# though it's not entirely valid.
self._getTimezoneOffsetTest("Africa/Johannesburg", -7200, -7200)
def test_timeFormatting(self):
"""
Test the method of L{FileLogObserver} which turns a timestamp into a
human-readable string.
"""
when = calendar.timegm((2001, 2, 3, 4, 5, 6, 7, 8, 0))
# Pretend to be in US/Eastern for a moment
self.flo.getTimezoneOffset = lambda when: 18000
self.assertEqual(self.flo.formatTime(when), '2001-02-02 23:05:06-0500')
# Okay now we're in Eastern Europe somewhere
self.flo.getTimezoneOffset = lambda when: -3600
self.assertEqual(self.flo.formatTime(when), '2001-02-03 05:05:06+0100')
# And off in the Pacific or someplace like that
self.flo.getTimezoneOffset = lambda when: -39600
self.assertEqual(self.flo.formatTime(when), '2001-02-03 15:05:06+1100')
# One of those weird places with a half-hour offset timezone
self.flo.getTimezoneOffset = lambda when: 5400
self.assertEqual(self.flo.formatTime(when), '2001-02-03 02:35:06-0130')
# Half-hour offset in the other direction
self.flo.getTimezoneOffset = lambda when: -5400
self.assertEqual(self.flo.formatTime(when), '2001-02-03 05:35:06+0130')
# Test an offset which is between 0 and 60 minutes to make sure the
# sign comes out properly in that case.
self.flo.getTimezoneOffset = lambda when: 1800
self.assertEqual(self.flo.formatTime(when), '2001-02-03 03:35:06-0030')
# Test an offset between 0 and 60 minutes in the other direction.
self.flo.getTimezoneOffset = lambda when: -1800
self.assertEqual(self.flo.formatTime(when), '2001-02-03 04:35:06+0030')
# If a strftime-format string is present on the logger, it should
# use that instead. Note we don't assert anything about day, hour
# or minute because we cannot easily control what time.strftime()
# thinks the local timezone is.
self.flo.timeFormat = '%Y %m'
self.assertEqual(self.flo.formatTime(when), '2001 02')
def test_microsecondTimestampFormatting(self):
"""
L{FileLogObserver.formatTime} supports a value of C{timeFormat} which
includes C{"%f"}, a L{datetime}-only format specifier for microseconds.
"""
self.flo.timeFormat = '%f'
self.assertEqual("600000", self.flo.formatTime(12345.6))
def test_loggingAnObjectWithBroken__str__(self):
#HELLO, MCFLY
self.lp.msg(EvilStr())
self.assertEqual(len(self.out), 1)
# Logging system shouldn't need to crap itself for this trivial case
self.assertNotIn('UNFORMATTABLE', self.out[0])
def test_formattingAnObjectWithBroken__str__(self):
self.lp.msg(format='%(blat)s', blat=EvilStr())
self.assertEqual(len(self.out), 1)
self.assertIn('Invalid format string or unformattable object', self.out[0])
def test_brokenSystem__str__(self):
self.lp.msg('huh', system=EvilStr())
self.assertEqual(len(self.out), 1)
self.assertIn('Invalid format string or unformattable object', self.out[0])
def test_formattingAnObjectWithBroken__repr__Indirect(self):
self.lp.msg(format='%(blat)s', blat=[EvilRepr()])
self.assertEqual(len(self.out), 1)
self.assertIn('UNFORMATTABLE OBJECT', self.out[0])
def test_systemWithBroker__repr__Indirect(self):
self.lp.msg('huh', system=[EvilRepr()])
self.assertEqual(len(self.out), 1)
self.assertIn('UNFORMATTABLE OBJECT', self.out[0])
def test_simpleBrokenFormat(self):
self.lp.msg(format='hooj %s %s', blat=1)
self.assertEqual(len(self.out), 1)
self.assertIn('Invalid format string or unformattable object', self.out[0])
def test_ridiculousFormat(self):
self.lp.msg(format=42, blat=1)
self.assertEqual(len(self.out), 1)
self.assertIn('Invalid format string or unformattable object', self.out[0])
def test_evilFormat__repr__And__str__(self):
self.lp.msg(format=EvilReprStr(), blat=1)
self.assertEqual(len(self.out), 1)
self.assertIn('PATHOLOGICAL', self.out[0])
def test_strangeEventDict(self):
"""
This kind of eventDict used to fail silently, so test it does.
"""
self.lp.msg(message='', isError=False)
self.assertEqual(len(self.out), 0)
def _startLoggingCleanup(self):
"""
Cleanup after a startLogging() call that mutates the hell out of some
global state.
"""
origShowwarnings = log._oldshowwarning
self.addCleanup(setattr, log, "_oldshowwarning", origShowwarnings)
self.addCleanup(setattr, sys, 'stdout', sys.stdout)
self.addCleanup(setattr, sys, 'stderr', sys.stderr)
def test_startLogging(self):
"""
startLogging() installs FileLogObserver and overrides sys.stdout and
sys.stderr.
"""
origStdout, origStderr = sys.stdout, sys.stderr
self._startLoggingCleanup()
# When done with test, reset stdout and stderr to current values:
fakeFile = StringIO()
observer = log.startLogging(fakeFile)
self.addCleanup(observer.stop)
log.msg("Hello!")
self.assertIn("Hello!", fakeFile.getvalue())
self.assertIsInstance(sys.stdout, log.StdioOnnaStick)
self.assertEqual(sys.stdout.isError, False)
encoding = getattr(origStdout, "encoding", None)
if not encoding:
encoding = sys.getdefaultencoding()
self.assertEqual(sys.stdout.encoding, encoding)
self.assertIsInstance(sys.stderr, log.StdioOnnaStick)
self.assertEqual(sys.stderr.isError, True)
encoding = getattr(origStderr, "encoding", None)
if not encoding:
encoding = sys.getdefaultencoding()
self.assertEqual(sys.stderr.encoding, encoding)
def test_startLoggingTwice(self):
"""
There are some obscure error conditions that can occur when logging is
started twice. See http://twistedmatrix.com/trac/ticket/3289 for more
information.
"""
self._startLoggingCleanup()
# The bug is particular to the way that the t.p.log 'global' function
# handle stdout. If we use our own stream, the error doesn't occur. If
# we use our own LogPublisher, the error doesn't occur.
sys.stdout = StringIO()
def showError(eventDict):
if eventDict['isError']:
sys.__stdout__.write(eventDict['failure'].getTraceback())
log.addObserver(showError)
self.addCleanup(log.removeObserver, showError)
observer = log.startLogging(sys.stdout)
self.addCleanup(observer.stop)
# At this point, we expect that sys.stdout is a StdioOnnaStick object.
self.assertIsInstance(sys.stdout, log.StdioOnnaStick)
fakeStdout = sys.stdout
observer = log.startLogging(sys.stdout)
self.assertIdentical(sys.stdout, fakeStdout)
def test_startLoggingOverridesWarning(self):
"""
startLogging() overrides global C{warnings.showwarning} such that
warnings go to Twisted log observers.
"""
self._startLoggingCleanup()
# Ugggh, pretend we're starting from newly imported module:
log._oldshowwarning = None
fakeFile = StringIO()
observer = log.startLogging(fakeFile)
self.addCleanup(observer.stop)
warnings.warn("hello!")
output = fakeFile.getvalue()
self.assertIn("UserWarning: hello!", output)
class PythonLoggingObserverTestCase(unittest.SynchronousTestCase):
"""
Test the bridge with python logging module.
"""
def setUp(self):
self.out = StringIO()
rootLogger = logging.getLogger("")
self.originalLevel = rootLogger.getEffectiveLevel()
rootLogger.setLevel(logging.DEBUG)
self.hdlr = logging.StreamHandler(self.out)
fmt = logging.Formatter(logging.BASIC_FORMAT)
self.hdlr.setFormatter(fmt)
rootLogger.addHandler(self.hdlr)
self.lp = log.LogPublisher()
self.obs = log.PythonLoggingObserver()
self.lp.addObserver(self.obs.emit)
def tearDown(self):
rootLogger = logging.getLogger("")
rootLogger.removeHandler(self.hdlr)
rootLogger.setLevel(self.originalLevel)
logging.shutdown()
def test_singleString(self):
"""
Test simple output, and default log level.
"""
self.lp.msg("Hello, world.")
self.assertIn("Hello, world.", self.out.getvalue())
self.assertIn("INFO", self.out.getvalue())
def test_errorString(self):
"""
Test error output.
"""
self.lp.msg(failure=failure.Failure(ValueError("That is bad.")), isError=True)
self.assertIn("ERROR", self.out.getvalue())
def test_formatString(self):
"""
Test logging with a format.
"""
self.lp.msg(format="%(bar)s oo %(foo)s", bar="Hello", foo="world")
self.assertIn("Hello oo world", self.out.getvalue())
def test_customLevel(self):
"""
Test the logLevel keyword for customizing level used.
"""
self.lp.msg("Spam egg.", logLevel=logging.DEBUG)
self.assertIn("Spam egg.", self.out.getvalue())
self.assertIn("DEBUG", self.out.getvalue())
self.out.seek(0, 0)
self.out.truncate()
self.lp.msg("Foo bar.", logLevel=logging.WARNING)
self.assertIn("Foo bar.", self.out.getvalue())
self.assertIn("WARNING", self.out.getvalue())
def test_strangeEventDict(self):
"""
Verify that an event dictionary which is not an error and has an empty
message isn't recorded.
"""
self.lp.msg(message='', isError=False)
self.assertEqual(self.out.getvalue(), '')
class PythonLoggingIntegrationTestCase(unittest.SynchronousTestCase):
"""
Test integration of python logging bridge.
"""
def test_startStopObserver(self):
"""
Test that start and stop methods of the observer actually register
and unregister to the log system.
"""
oldAddObserver = log.addObserver
oldRemoveObserver = log.removeObserver
l = []
try:
log.addObserver = l.append
log.removeObserver = l.remove
obs = log.PythonLoggingObserver()
obs.start()
self.assertEqual(l[0], obs.emit)
obs.stop()
self.assertEqual(len(l), 0)
finally:
log.addObserver = oldAddObserver
log.removeObserver = oldRemoveObserver
def test_inheritance(self):
"""
Test that we can inherit L{log.PythonLoggingObserver} and use super:
that's basically a validation that L{log.PythonLoggingObserver} is
new-style class.
"""
class MyObserver(log.PythonLoggingObserver):
def emit(self, eventDict):
super(MyObserver, self).emit(eventDict)
obs = MyObserver()
l = []
oldEmit = log.PythonLoggingObserver.emit
try:
log.PythonLoggingObserver.emit = l.append
obs.emit('foo')
self.assertEqual(len(l), 1)
finally:
log.PythonLoggingObserver.emit = oldEmit
class DefaultObserverTestCase(unittest.SynchronousTestCase):
"""
Test the default observer.
"""
def test_failureLogger(self):
"""
The reason argument passed to log.err() appears in the report
generated by DefaultObserver.
"""
self.catcher = []
self.observer = self.catcher.append
log.addObserver(self.observer)
self.addCleanup(log.removeObserver, self.observer)
obs = log.DefaultObserver()
obs.stderr = StringIO()
obs.start()
self.addCleanup(obs.stop)
reason = "The reason."
log.err(Exception(), reason)
errors = self.flushLoggedErrors()
self.assertIn(reason, obs.stderr.getvalue())
self.assertEqual(len(errors), 1)
class StdioOnnaStickTestCase(unittest.SynchronousTestCase):
"""
StdioOnnaStick should act like the normal sys.stdout object.
"""
def setUp(self):
self.resultLogs = []
log.addObserver(self.resultLogs.append)
def tearDown(self):
log.removeObserver(self.resultLogs.append)
def getLogMessages(self):
return ["".join(d['message']) for d in self.resultLogs]
def test_write(self):
"""
Writing to a StdioOnnaStick instance results in Twisted log messages.
Log messages are generated every time a '\n' is encountered.
"""
stdio = log.StdioOnnaStick()
stdio.write("Hello there\nThis is a test")
self.assertEqual(self.getLogMessages(), ["Hello there"])
stdio.write("!\n")
self.assertEqual(self.getLogMessages(), ["Hello there", "This is a test!"])
def test_metadata(self):
"""
The log messages written by StdioOnnaStick have printed=1 keyword, and
by default are not errors.
"""
stdio = log.StdioOnnaStick()
stdio.write("hello\n")
self.assertEqual(self.resultLogs[0]['isError'], False)
self.assertEqual(self.resultLogs[0]['printed'], True)
def test_writeLines(self):
"""
Writing lines to a StdioOnnaStick results in Twisted log messages.
"""
stdio = log.StdioOnnaStick()
stdio.writelines(["log 1", "log 2"])
self.assertEqual(self.getLogMessages(), ["log 1", "log 2"])
def test_print(self):
"""
When StdioOnnaStick is set as sys.stdout, prints become log messages.
"""
oldStdout = sys.stdout
sys.stdout = log.StdioOnnaStick()
self.addCleanup(setattr, sys, "stdout", oldStdout)
print("This", end=" ")
print("is a test")
self.assertEqual(self.getLogMessages(), ["This is a test"])
def test_error(self):
"""
StdioOnnaStick created with isError=True log messages as errors.
"""
stdio = log.StdioOnnaStick(isError=True)
stdio.write("log 1\n")
self.assertEqual(self.resultLogs[0]['isError'], True)
def test_unicode(self):
"""
StdioOnnaStick converts unicode prints to byte strings on Python 2, in
order to be compatible with the normal stdout/stderr objects.
On Python 3, the prints are left unmodified.
"""
unicodeString = u"Hello, \N{VULGAR FRACTION ONE HALF} world."
stdio = log.StdioOnnaStick(encoding="utf-8")
self.assertEqual(stdio.encoding, "utf-8")
stdio.write(unicodeString + u"\n")
stdio.writelines([u"Also, " + unicodeString])
oldStdout = sys.stdout
sys.stdout = stdio
self.addCleanup(setattr, sys, "stdout", oldStdout)
# This should go to the log, utf-8 encoded too:
print(unicodeString)
if _PY3:
self.assertEqual(self.getLogMessages(),
[unicodeString,
u"Also, " + unicodeString,
unicodeString])
else:
self.assertEqual(self.getLogMessages(),
[unicodeString.encode("utf-8"),
(u"Also, " + unicodeString).encode("utf-8"),
unicodeString.encode("utf-8")])

View file

@ -0,0 +1,320 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, time, stat, errno
from twisted.trial import unittest
from twisted.python import logfile, runtime
class LogFileTestCase(unittest.TestCase):
"""
Test the rotating log file.
"""
def setUp(self):
self.dir = self.mktemp()
os.makedirs(self.dir)
self.name = "test.log"
self.path = os.path.join(self.dir, self.name)
def tearDown(self):
"""
Restore back write rights on created paths: if tests modified the
rights, that will allow the paths to be removed easily afterwards.
"""
os.chmod(self.dir, 0777)
if os.path.exists(self.path):
os.chmod(self.path, 0777)
def testWriting(self):
log = logfile.LogFile(self.name, self.dir)
log.write("123")
log.write("456")
log.flush()
log.write("7890")
log.close()
f = open(self.path, "r")
self.assertEqual(f.read(), "1234567890")
f.close()
def testRotation(self):
# this logfile should rotate every 10 bytes
log = logfile.LogFile(self.name, self.dir, rotateLength=10)
# test automatic rotation
log.write("123")
log.write("4567890")
log.write("1" * 11)
self.assert_(os.path.exists("%s.1" % self.path))
self.assert_(not os.path.exists("%s.2" % self.path))
log.write('')
self.assert_(os.path.exists("%s.1" % self.path))
self.assert_(os.path.exists("%s.2" % self.path))
self.assert_(not os.path.exists("%s.3" % self.path))
log.write("3")
self.assert_(not os.path.exists("%s.3" % self.path))
# test manual rotation
log.rotate()
self.assert_(os.path.exists("%s.3" % self.path))
self.assert_(not os.path.exists("%s.4" % self.path))
log.close()
self.assertEqual(log.listLogs(), [1, 2, 3])
def testAppend(self):
log = logfile.LogFile(self.name, self.dir)
log.write("0123456789")
log.close()
log = logfile.LogFile(self.name, self.dir)
self.assertEqual(log.size, 10)
self.assertEqual(log._file.tell(), log.size)
log.write("abc")
self.assertEqual(log.size, 13)
self.assertEqual(log._file.tell(), log.size)
f = log._file
f.seek(0, 0)
self.assertEqual(f.read(), "0123456789abc")
log.close()
def testLogReader(self):
log = logfile.LogFile(self.name, self.dir)
log.write("abc\n")
log.write("def\n")
log.rotate()
log.write("ghi\n")
log.flush()
# check reading logs
self.assertEqual(log.listLogs(), [1])
reader = log.getCurrentLog()
reader._file.seek(0)
self.assertEqual(reader.readLines(), ["ghi\n"])
self.assertEqual(reader.readLines(), [])
reader.close()
reader = log.getLog(1)
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
self.assertEqual(reader.readLines(), [])
reader.close()
# check getting illegal log readers
self.assertRaises(ValueError, log.getLog, 2)
self.assertRaises(TypeError, log.getLog, "1")
# check that log numbers are higher for older logs
log.rotate()
self.assertEqual(log.listLogs(), [1, 2])
reader = log.getLog(1)
reader._file.seek(0)
self.assertEqual(reader.readLines(), ["ghi\n"])
self.assertEqual(reader.readLines(), [])
reader.close()
reader = log.getLog(2)
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
self.assertEqual(reader.readLines(), [])
reader.close()
def testModePreservation(self):
"""
Check rotated files have same permissions as original.
"""
f = open(self.path, "w").close()
os.chmod(self.path, 0707)
mode = os.stat(self.path)[stat.ST_MODE]
log = logfile.LogFile(self.name, self.dir)
log.write("abc")
log.rotate()
self.assertEqual(mode, os.stat(self.path)[stat.ST_MODE])
def test_noPermission(self):
"""
Check it keeps working when permission on dir changes.
"""
log = logfile.LogFile(self.name, self.dir)
log.write("abc")
# change permissions so rotation would fail
os.chmod(self.dir, 0555)
# if this succeeds, chmod doesn't restrict us, so we can't
# do the test
try:
f = open(os.path.join(self.dir,"xxx"), "w")
except (OSError, IOError):
pass
else:
f.close()
return
log.rotate() # this should not fail
log.write("def")
log.flush()
f = log._file
self.assertEqual(f.tell(), 6)
f.seek(0, 0)
self.assertEqual(f.read(), "abcdef")
log.close()
def test_maxNumberOfLog(self):
"""
Test it respect the limit on the number of files when maxRotatedFiles
is not None.
"""
log = logfile.LogFile(self.name, self.dir, rotateLength=10,
maxRotatedFiles=3)
log.write("1" * 11)
log.write("2" * 11)
self.failUnless(os.path.exists("%s.1" % self.path))
log.write("3" * 11)
self.failUnless(os.path.exists("%s.2" % self.path))
log.write("4" * 11)
self.failUnless(os.path.exists("%s.3" % self.path))
self.assertEqual(file("%s.3" % self.path).read(), "1" * 11)
log.write("5" * 11)
self.assertEqual(file("%s.3" % self.path).read(), "2" * 11)
self.failUnless(not os.path.exists("%s.4" % self.path))
def test_fromFullPath(self):
"""
Test the fromFullPath method.
"""
log1 = logfile.LogFile(self.name, self.dir, 10, defaultMode=0777)
log2 = logfile.LogFile.fromFullPath(self.path, 10, defaultMode=0777)
self.assertEqual(log1.name, log2.name)
self.assertEqual(os.path.abspath(log1.path), log2.path)
self.assertEqual(log1.rotateLength, log2.rotateLength)
self.assertEqual(log1.defaultMode, log2.defaultMode)
def test_defaultPermissions(self):
"""
Test the default permission of the log file: if the file exist, it
should keep the permission.
"""
f = file(self.path, "w")
os.chmod(self.path, 0707)
currentMode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
f.close()
log1 = logfile.LogFile(self.name, self.dir)
self.assertEqual(stat.S_IMODE(os.stat(self.path)[stat.ST_MODE]),
currentMode)
def test_specifiedPermissions(self):
"""
Test specifying the permissions used on the log file.
"""
log1 = logfile.LogFile(self.name, self.dir, defaultMode=0066)
mode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
if runtime.platform.isWindows():
# The only thing we can get here is global read-only
self.assertEqual(mode, 0444)
else:
self.assertEqual(mode, 0066)
def test_reopen(self):
"""
L{logfile.LogFile.reopen} allows to rename the currently used file and
make L{logfile.LogFile} create a new file.
"""
log1 = logfile.LogFile(self.name, self.dir)
log1.write("hello1")
savePath = os.path.join(self.dir, "save.log")
os.rename(self.path, savePath)
log1.reopen()
log1.write("hello2")
log1.close()
f = open(self.path, "r")
self.assertEqual(f.read(), "hello2")
f.close()
f = open(savePath, "r")
self.assertEqual(f.read(), "hello1")
f.close()
if runtime.platform.isWindows():
test_reopen.skip = "Can't test reopen on Windows"
def test_nonExistentDir(self):
"""
Specifying an invalid directory to L{LogFile} raises C{IOError}.
"""
e = self.assertRaises(
IOError, logfile.LogFile, self.name, 'this_dir_does_not_exist')
self.assertEqual(e.errno, errno.ENOENT)
class RiggedDailyLogFile(logfile.DailyLogFile):
_clock = 0.0
def _openFile(self):
logfile.DailyLogFile._openFile(self)
# rig the date to match _clock, not mtime
self.lastDate = self.toDate()
def toDate(self, *args):
if args:
return time.gmtime(*args)[:3]
return time.gmtime(self._clock)[:3]
class DailyLogFileTestCase(unittest.TestCase):
"""
Test rotating log file.
"""
def setUp(self):
self.dir = self.mktemp()
os.makedirs(self.dir)
self.name = "testdaily.log"
self.path = os.path.join(self.dir, self.name)
def testWriting(self):
log = RiggedDailyLogFile(self.name, self.dir)
log.write("123")
log.write("456")
log.flush()
log.write("7890")
log.close()
f = open(self.path, "r")
self.assertEqual(f.read(), "1234567890")
f.close()
def testRotation(self):
# this logfile should rotate every 10 bytes
log = RiggedDailyLogFile(self.name, self.dir)
days = [(self.path + '.' + log.suffix(day * 86400)) for day in range(3)]
# test automatic rotation
log._clock = 0.0 # 1970/01/01 00:00.00
log.write("123")
log._clock = 43200 # 1970/01/01 12:00.00
log.write("4567890")
log._clock = 86400 # 1970/01/02 00:00.00
log.write("1" * 11)
self.assert_(os.path.exists(days[0]))
self.assert_(not os.path.exists(days[1]))
log._clock = 172800 # 1970/01/03 00:00.00
log.write('')
self.assert_(os.path.exists(days[0]))
self.assert_(os.path.exists(days[1]))
self.assert_(not os.path.exists(days[2]))
log._clock = 259199 # 1970/01/03 23:59.59
log.write("3")
self.assert_(not os.path.exists(days[2]))

View file

@ -0,0 +1,431 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test case for L{twisted.protocols.loopback}.
"""
from __future__ import division, absolute_import
from zope.interface import implementer
from twisted.python.compat import _PY3, intToBytes
from twisted.trial import unittest
from twisted.trial.util import suppress as SUPPRESS
from twisted.protocols import basic, loopback
from twisted.internet import defer
from twisted.internet.protocol import Protocol
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
from twisted.internet import reactor, interfaces
class SimpleProtocol(basic.LineReceiver):
def __init__(self):
self.conn = defer.Deferred()
self.lines = []
self.connLost = []
def connectionMade(self):
self.conn.callback(None)
def lineReceived(self, line):
self.lines.append(line)
def connectionLost(self, reason):
self.connLost.append(reason)
class DoomProtocol(SimpleProtocol):
i = 0
def lineReceived(self, line):
self.i += 1
if self.i < 4:
# by this point we should have connection closed,
# but just in case we didn't we won't ever send 'Hello 4'
self.sendLine(b"Hello " + intToBytes(self.i))
SimpleProtocol.lineReceived(self, line)
if self.lines[-1] == b"Hello 3":
self.transport.loseConnection()
class LoopbackTestCaseMixin:
def testRegularFunction(self):
s = SimpleProtocol()
c = SimpleProtocol()
def sendALine(result):
s.sendLine(b"THIS IS LINE ONE!")
s.transport.loseConnection()
s.conn.addCallback(sendALine)
def check(ignored):
self.assertEqual(c.lines, [b"THIS IS LINE ONE!"])
self.assertEqual(len(s.connLost), 1)
self.assertEqual(len(c.connLost), 1)
d = defer.maybeDeferred(self.loopbackFunc, s, c)
d.addCallback(check)
return d
def testSneakyHiddenDoom(self):
s = DoomProtocol()
c = DoomProtocol()
def sendALine(result):
s.sendLine(b"DOOM LINE")
s.conn.addCallback(sendALine)
def check(ignored):
self.assertEqual(s.lines, [b'Hello 1', b'Hello 2', b'Hello 3'])
self.assertEqual(
c.lines, [b'DOOM LINE', b'Hello 1', b'Hello 2', b'Hello 3'])
self.assertEqual(len(s.connLost), 1)
self.assertEqual(len(c.connLost), 1)
d = defer.maybeDeferred(self.loopbackFunc, s, c)
d.addCallback(check)
return d
class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
loopbackFunc = staticmethod(loopback.loopbackAsync)
def test_makeConnection(self):
"""
Test that the client and server protocol both have makeConnection
invoked on them by loopbackAsync.
"""
class TestProtocol(Protocol):
transport = None
def makeConnection(self, transport):
self.transport = transport
server = TestProtocol()
client = TestProtocol()
loopback.loopbackAsync(server, client)
self.failIfEqual(client.transport, None)
self.failIfEqual(server.transport, None)
def _hostpeertest(self, get, testServer):
"""
Test one of the permutations of client/server host/peer.
"""
class TestProtocol(Protocol):
def makeConnection(self, transport):
Protocol.makeConnection(self, transport)
self.onConnection.callback(transport)
if testServer:
server = TestProtocol()
d = server.onConnection = Deferred()
client = Protocol()
else:
server = Protocol()
client = TestProtocol()
d = client.onConnection = Deferred()
loopback.loopbackAsync(server, client)
def connected(transport):
host = getattr(transport, get)()
self.failUnless(IAddress.providedBy(host))
return d.addCallback(connected)
def test_serverHost(self):
"""
Test that the server gets a transport with a properly functioning
implementation of L{ITransport.getHost}.
"""
return self._hostpeertest("getHost", True)
def test_serverPeer(self):
"""
Like C{test_serverHost} but for L{ITransport.getPeer}
"""
return self._hostpeertest("getPeer", True)
def test_clientHost(self, get="getHost"):
"""
Test that the client gets a transport with a properly functioning
implementation of L{ITransport.getHost}.
"""
return self._hostpeertest("getHost", False)
def test_clientPeer(self):
"""
Like C{test_clientHost} but for L{ITransport.getPeer}.
"""
return self._hostpeertest("getPeer", False)
def _greetingtest(self, write, testServer):
"""
Test one of the permutations of write/writeSequence client/server.
@param write: The name of the method to test, C{"write"} or
C{"writeSequence"}.
"""
class GreeteeProtocol(Protocol):
bytes = b""
def dataReceived(self, bytes):
self.bytes += bytes
if self.bytes == b"bytes":
self.received.callback(None)
class GreeterProtocol(Protocol):
def connectionMade(self):
if write == "write":
self.transport.write(b"bytes")
else:
self.transport.writeSequence([b"byt", b"es"])
if testServer:
server = GreeterProtocol()
client = GreeteeProtocol()
d = client.received = Deferred()
else:
server = GreeteeProtocol()
d = server.received = Deferred()
client = GreeterProtocol()
loopback.loopbackAsync(server, client)
return d
def test_clientGreeting(self):
"""
Test that on a connection where the client speaks first, the server
receives the bytes sent by the client.
"""
return self._greetingtest("write", False)
def test_clientGreetingSequence(self):
"""
Like C{test_clientGreeting}, but use C{writeSequence} instead of
C{write} to issue the greeting.
"""
return self._greetingtest("writeSequence", False)
def test_serverGreeting(self, write="write"):
"""
Test that on a connection where the server speaks first, the client
receives the bytes sent by the server.
"""
return self._greetingtest("write", True)
def test_serverGreetingSequence(self):
"""
Like C{test_serverGreeting}, but use C{writeSequence} instead of
C{write} to issue the greeting.
"""
return self._greetingtest("writeSequence", True)
def _producertest(self, producerClass):
toProduce = list(map(intToBytes, range(0, 10)))
class ProducingProtocol(Protocol):
def connectionMade(self):
self.producer = producerClass(list(toProduce))
self.producer.start(self.transport)
class ReceivingProtocol(Protocol):
bytes = b""
def dataReceived(self, data):
self.bytes += data
if self.bytes == b''.join(toProduce):
self.received.callback((client, server))
server = ProducingProtocol()
client = ReceivingProtocol()
client.received = Deferred()
loopback.loopbackAsync(server, client)
return client.received
def test_pushProducer(self):
"""
Test a push producer registered against a loopback transport.
"""
@implementer(IPushProducer)
class PushProducer(object):
resumed = False
def __init__(self, toProduce):
self.toProduce = toProduce
def resumeProducing(self):
self.resumed = True
def start(self, consumer):
self.consumer = consumer
consumer.registerProducer(self, True)
self._produceAndSchedule()
def _produceAndSchedule(self):
if self.toProduce:
self.consumer.write(self.toProduce.pop(0))
reactor.callLater(0, self._produceAndSchedule)
else:
self.consumer.unregisterProducer()
d = self._producertest(PushProducer)
def finished(results):
(client, server) = results
self.assertFalse(
server.producer.resumed,
"Streaming producer should not have been resumed.")
d.addCallback(finished)
return d
def test_pullProducer(self):
"""
Test a pull producer registered against a loopback transport.
"""
@implementer(IPullProducer)
class PullProducer(object):
def __init__(self, toProduce):
self.toProduce = toProduce
def start(self, consumer):
self.consumer = consumer
self.consumer.registerProducer(self, False)
def resumeProducing(self):
self.consumer.write(self.toProduce.pop(0))
if not self.toProduce:
self.consumer.unregisterProducer()
return self._producertest(PullProducer)
def test_writeNotReentrant(self):
"""
L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
method while that protocol's transport's C{write} method is higher up
on the stack.
"""
class Server(Protocol):
def dataReceived(self, bytes):
self.transport.write(b"bytes")
class Client(Protocol):
ready = False
def connectionMade(self):
reactor.callLater(0, self.go)
def go(self):
self.transport.write(b"foo")
self.ready = True
def dataReceived(self, bytes):
self.wasReady = self.ready
self.transport.loseConnection()
server = Server()
client = Client()
d = loopback.loopbackAsync(client, server)
def cbFinished(ignored):
self.assertTrue(client.wasReady)
d.addCallback(cbFinished)
return d
def test_pumpPolicy(self):
"""
The callable passed as the value for the C{pumpPolicy} parameter to
L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
and a protocol to which they should be delivered.
"""
pumpCalls = []
def dummyPolicy(queue, target):
bytes = []
while queue:
bytes.append(queue.get())
pumpCalls.append((target, bytes))
client = Protocol()
server = Protocol()
finished = loopback.loopbackAsync(server, client, dummyPolicy)
self.assertEqual(pumpCalls, [])
client.transport.write(b"foo")
client.transport.write(b"bar")
server.transport.write(b"baz")
server.transport.write(b"quux")
server.transport.loseConnection()
def cbComplete(ignored):
self.assertEqual(
pumpCalls,
# The order here is somewhat arbitrary. The implementation
# happens to always deliver data to the client first.
[(client, [b"baz", b"quux", None]),
(server, [b"foo", b"bar"])])
finished.addCallback(cbComplete)
return finished
def test_identityPumpPolicy(self):
"""
L{identityPumpPolicy} is a pump policy which calls the target's
C{dataReceived} method one for each string in the queue passed to it.
"""
bytes = []
client = Protocol()
client.dataReceived = bytes.append
queue = loopback._LoopbackQueue()
queue.put(b"foo")
queue.put(b"bar")
queue.put(None)
loopback.identityPumpPolicy(queue, client)
self.assertEqual(bytes, [b"foo", b"bar"])
def test_collapsingPumpPolicy(self):
"""
L{collapsingPumpPolicy} is a pump policy which calls the target's
C{dataReceived} only once with all of the strings in the queue passed
to it joined together.
"""
bytes = []
client = Protocol()
client.dataReceived = bytes.append
queue = loopback._LoopbackQueue()
queue.put(b"foo")
queue.put(b"bar")
queue.put(None)
loopback.collapsingPumpPolicy(queue, client)
self.assertEqual(bytes, [b"foobar"])
class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
loopbackFunc = staticmethod(loopback.loopbackTCP)
class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
loopbackFunc = staticmethod(loopback.loopbackUNIX)
if interfaces.IReactorUNIX(reactor, None) is None:
skip = "Current reactor does not support UNIX sockets"
elif _PY3:
skip = "UNIX sockets not supported on Python 3. See #6136"

View file

@ -0,0 +1,75 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.manhole import service
from twisted.spread.util import LocalAsRemote
class Dummy:
pass
class DummyTransport:
def getHost(self):
return 'INET', '127.0.0.1', 0
class DummyManholeClient(LocalAsRemote):
zero = 0
broker = Dummy()
broker.transport = DummyTransport()
def __init__(self):
self.messages = []
def console(self, messages):
self.messages.extend(messages)
def receiveExplorer(self, xplorer):
pass
def setZero(self):
self.zero = len(self.messages)
def getMessages(self):
return self.messages[self.zero:]
# local interface
sync_console = console
sync_receiveExplorer = receiveExplorer
sync_setZero = setZero
sync_getMessages = getMessages
class ManholeTest(unittest.TestCase):
"""Various tests for the manhole service.
Both the the importIdentity and importMain tests are known to fail
when the __name__ in the manhole namespace is set to certain
values.
"""
def setUp(self):
self.service = service.Service()
self.p = service.Perspective(self.service)
self.client = DummyManholeClient()
self.p.attached(self.client, None)
def test_importIdentity(self):
"""Making sure imported module is the same as one previously loaded.
"""
self.p.perspective_do("from twisted.manhole import service")
self.client.setZero()
self.p.perspective_do("int(service is sys.modules['twisted.manhole.service'])")
msg = self.client.getMessages()[0]
self.assertEqual(msg, ('result',"1\n"))
def test_importMain(self):
"""Trying to import __main__"""
self.client.setZero()
self.p.perspective_do("import __main__")
if self.client.getMessages():
msg = self.client.getMessages()[0]
if msg[0] in ("exception","stderr"):
self.fail(msg[1])
#if __name__=='__main__':
# unittest.main()

View file

@ -0,0 +1,699 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test the memcache client protocol.
"""
from twisted.internet.error import ConnectionDone
from twisted.protocols.memcache import MemCacheProtocol, NoSuchCommand
from twisted.protocols.memcache import ClientError, ServerError
from twisted.trial.unittest import TestCase
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred, gatherResults, TimeoutError
from twisted.internet.defer import DeferredList
class CommandMixin:
"""
Setup and tests for basic invocation of L{MemCacheProtocol} commands.
"""
def _test(self, d, send, recv, result):
"""
Helper test method to test the resulting C{Deferred} of a
L{MemCacheProtocol} command.
"""
raise NotImplementedError()
def test_get(self):
"""
L{MemCacheProtocol.get} returns a L{Deferred} which is called back with
the value and the flag associated with the given key if the server
returns a successful result.
"""
return self._test(
self.proto.get("foo"), "get foo\r\n",
"VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar"))
def test_emptyGet(self):
"""
Test getting a non-available key: it succeeds but return C{None} as
value and C{0} as flag.
"""
return self._test(
self.proto.get("foo"), "get foo\r\n",
"END\r\n", (0, None))
def test_getMultiple(self):
"""
L{MemCacheProtocol.getMultiple} returns a L{Deferred} which is called
back with a dictionary of flag, value for each given key.
"""
return self._test(
self.proto.getMultiple(['foo', 'cow']),
"get foo cow\r\n",
"VALUE foo 0 3\r\nbar\r\nVALUE cow 0 7\r\nchicken\r\nEND\r\n",
{'cow': (0, 'chicken'), 'foo': (0, 'bar')})
def test_getMultipleWithEmpty(self):
"""
When L{MemCacheProtocol.getMultiple} is called with non-available keys,
the corresponding tuples are (0, None).
"""
return self._test(
self.proto.getMultiple(['foo', 'cow']),
"get foo cow\r\n",
"VALUE cow 1 3\r\nbar\r\nEND\r\n",
{'cow': (1, 'bar'), 'foo': (0, None)})
def test_set(self):
"""
L{MemCacheProtocol.set} returns a L{Deferred} which is called back with
C{True} when the operation succeeds.
"""
return self._test(
self.proto.set("foo", "bar"),
"set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_add(self):
"""
L{MemCacheProtocol.add} returns a L{Deferred} which is called back with
C{True} when the operation succeeds.
"""
return self._test(
self.proto.add("foo", "bar"),
"add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_replace(self):
"""
L{MemCacheProtocol.replace} returns a L{Deferred} which is called back
with C{True} when the operation succeeds.
"""
return self._test(
self.proto.replace("foo", "bar"),
"replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_errorAdd(self):
"""
Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
key already exists on the server, it returns a B{NOT STORED} answer,
which calls back the resulting L{Deferred} with C{False}.
"""
return self._test(
self.proto.add("foo", "bar"),
"add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)
def test_errorReplace(self):
"""
Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
but the key doesn't exist on the server, it returns a B{NOT STORED}
answer, which calls back the resulting L{Deferred} with C{False}.
"""
return self._test(
self.proto.replace("foo", "bar"),
"replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)
def test_delete(self):
"""
L{MemCacheProtocol.delete} returns a L{Deferred} which is called back
with C{True} when the server notifies a success.
"""
return self._test(
self.proto.delete("bar"), "delete bar\r\n", "DELETED\r\n", True)
def test_errorDelete(self):
"""
Test a error during a delete: if key doesn't exist on the server, it
returns a B{NOT FOUND} answer which calls back the resulting
L{Deferred} with C{False}.
"""
return self._test(
self.proto.delete("bar"), "delete bar\r\n", "NOT FOUND\r\n", False)
def test_increment(self):
"""
Test incrementing a variable: L{MemCacheProtocol.increment} returns a
L{Deferred} which is called back with the incremented value of the
given key.
"""
return self._test(
self.proto.increment("foo"), "incr foo 1\r\n", "4\r\n", 4)
def test_decrement(self):
"""
Test decrementing a variable: L{MemCacheProtocol.decrement} returns a
L{Deferred} which is called back with the decremented value of the
given key.
"""
return self._test(
self.proto.decrement("foo"), "decr foo 1\r\n", "5\r\n", 5)
def test_incrementVal(self):
"""
L{MemCacheProtocol.increment} takes an optional argument C{value} which
replaces the default value of 1 when specified.
"""
return self._test(
self.proto.increment("foo", 8), "incr foo 8\r\n", "4\r\n", 4)
def test_decrementVal(self):
"""
L{MemCacheProtocol.decrement} takes an optional argument C{value} which
replaces the default value of 1 when specified.
"""
return self._test(
self.proto.decrement("foo", 3), "decr foo 3\r\n", "5\r\n", 5)
def test_stats(self):
"""
Test retrieving server statistics via the L{MemCacheProtocol.stats}
command: it parses the data sent by the server and calls back the
resulting L{Deferred} with a dictionary of the received statistics.
"""
return self._test(
self.proto.stats(), "stats\r\n",
"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
{"foo": "bar", "egg": "spam"})
def test_statsWithArgument(self):
"""
L{MemCacheProtocol.stats} takes an optional C{str} argument which,
if specified, is sent along with the I{STAT} command. The I{STAT}
responses from the server are parsed as key/value pairs and returned
as a C{dict} (as in the case where the argument is not specified).
"""
return self._test(
self.proto.stats("blah"), "stats blah\r\n",
"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
{"foo": "bar", "egg": "spam"})
def test_version(self):
"""
Test version retrieval via the L{MemCacheProtocol.version} command: it
returns a L{Deferred} which is called back with the version sent by the
server.
"""
return self._test(
self.proto.version(), "version\r\n", "VERSION 1.1\r\n", "1.1")
def test_flushAll(self):
"""
L{MemCacheProtocol.flushAll} returns a L{Deferred} which is called back
with C{True} if the server acknowledges success.
"""
return self._test(
self.proto.flushAll(), "flush_all\r\n", "OK\r\n", True)
class MemCacheTestCase(CommandMixin, TestCase):
"""
Test client protocol class L{MemCacheProtocol}.
"""
def setUp(self):
"""
Create a memcache client, connect it to a string protocol, and make it
use a deterministic clock.
"""
self.proto = MemCacheProtocol()
self.clock = Clock()
self.proto.callLater = self.clock.callLater
self.transport = StringTransportWithDisconnection()
self.transport.protocol = self.proto
self.proto.makeConnection(self.transport)
def _test(self, d, send, recv, result):
"""
Implementation of C{_test} which checks that the command sends C{send}
data, and that upon reception of C{recv} the result is C{result}.
@param d: the resulting deferred from the memcache command.
@type d: C{Deferred}
@param send: the expected data to be sent.
@type send: C{str}
@param recv: the data to simulate as reception.
@type recv: C{str}
@param result: the expected result.
@type result: C{any}
"""
def cb(res):
self.assertEqual(res, result)
self.assertEqual(self.transport.value(), send)
d.addCallback(cb)
self.proto.dataReceived(recv)
return d
def test_invalidGetResponse(self):
"""
If the value returned doesn't match the expected key of the current
C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
"""
self.proto.get("foo")
s = "spamegg"
self.assertRaises(
RuntimeError, self.proto.dataReceived,
"VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))
def test_invalidMultipleGetResponse(self):
"""
If the value returned doesn't match one the expected keys of the
current multiple C{get} command, an error is raised error in
L{MemCacheProtocol.dataReceived}.
"""
self.proto.getMultiple(["foo", "bar"])
s = "spamegg"
self.assertRaises(
RuntimeError, self.proto.dataReceived,
"VALUE egg 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))
def test_timeOut(self):
"""
Test the timeout on outgoing requests: when timeout is detected, all
current commands fail with a L{TimeoutError}, and the connection is
closed.
"""
d1 = self.proto.get("foo")
d2 = self.proto.get("bar")
d3 = Deferred()
self.proto.connectionLost = d3.callback
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, TimeoutError)
def checkMessage(error):
self.assertEqual(str(error), "Connection timeout")
d1.addCallback(checkMessage)
self.assertFailure(d3, ConnectionDone)
return gatherResults([d1, d2, d3])
def test_timeoutRemoved(self):
"""
When a request gets a response, no pending timeout call remains around.
"""
d = self.proto.get("foo")
self.clock.advance(self.proto.persistentTimeOut - 1)
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
def check(result):
self.assertEqual(result, (0, "bar"))
self.assertEqual(len(self.clock.calls), 0)
d.addCallback(check)
return d
def test_timeOutRaw(self):
"""
Test the timeout when raw mode was started: the timeout is not reset
until all the data has been received, so we can have a L{TimeoutError}
when waiting for raw data.
"""
d1 = self.proto.get("foo")
d2 = Deferred()
self.proto.connectionLost = d2.callback
self.proto.dataReceived("VALUE foo 0 10\r\n12345")
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, ConnectionDone)
return gatherResults([d1, d2])
def test_timeOutStat(self):
"""
Test the timeout when stat command has started: the timeout is not
reset until the final B{END} is received.
"""
d1 = self.proto.stats()
d2 = Deferred()
self.proto.connectionLost = d2.callback
self.proto.dataReceived("STAT foo bar\r\n")
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, ConnectionDone)
return gatherResults([d1, d2])
def test_timeoutPipelining(self):
"""
When two requests are sent, a timeout call remains around for the
second request, and its timeout time is correct.
"""
d1 = self.proto.get("foo")
d2 = self.proto.get("bar")
d3 = Deferred()
self.proto.connectionLost = d3.callback
self.clock.advance(self.proto.persistentTimeOut - 1)
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
def check(result):
self.assertEqual(result, (0, "bar"))
self.assertEqual(len(self.clock.calls), 1)
for i in range(self.proto.persistentTimeOut):
self.clock.advance(1)
return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
def checkTime(ignored):
# Check that the timeout happened C{self.proto.persistentTimeOut}
# after the last response
self.assertEqual(
self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1)
d1.addCallback(check)
self.assertFailure(d3, ConnectionDone)
return d1
def test_timeoutNotReset(self):
"""
Check that timeout is not resetted for every command, but keep the
timeout from the first command without response.
"""
d1 = self.proto.get("foo")
d3 = Deferred()
self.proto.connectionLost = d3.callback
self.clock.advance(self.proto.persistentTimeOut - 1)
d2 = self.proto.get("bar")
self.clock.advance(1)
self.assertFailure(d1, TimeoutError)
self.assertFailure(d2, TimeoutError)
self.assertFailure(d3, ConnectionDone)
return gatherResults([d1, d2, d3])
def test_timeoutCleanDeferreds(self):
"""
C{timeoutConnection} cleans the list of commands that it fires with
C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
sets the disconnected state so that future commands fail with a
C{RuntimeError}.
"""
d1 = self.proto.get("foo")
self.clock.advance(self.proto.persistentTimeOut)
self.assertFailure(d1, TimeoutError)
d2 = self.proto.get("bar")
self.assertFailure(d2, RuntimeError)
return gatherResults([d1, d2])
def test_connectionLost(self):
"""
When disconnection occurs while commands are still outstanding, the
commands fail.
"""
d1 = self.proto.get("foo")
d2 = self.proto.get("bar")
self.transport.loseConnection()
done = DeferredList([d1, d2], consumeErrors=True)
def checkFailures(results):
for success, result in results:
self.assertFalse(success)
result.trap(ConnectionDone)
return done.addCallback(checkFailures)
def test_tooLongKey(self):
"""
An error is raised when trying to use a too long key: the called
command returns a L{Deferred} which fails with a L{ClientError}.
"""
d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
d4 = self.assertFailure(
self.proto.append("a" * 500, "bar"), ClientError)
d5 = self.assertFailure(
self.proto.prepend("a" * 500, "bar"), ClientError)
d6 = self.assertFailure(
self.proto.getMultiple(["foo", "a" * 500]), ClientError)
return gatherResults([d1, d2, d3, d4, d5, d6])
def test_invalidCommand(self):
"""
When an unknown command is sent directly (not through public API), the
server answers with an B{ERROR} token, and the command fails with
L{NoSuchCommand}.
"""
d = self.proto._set("egg", "foo", "bar", 0, 0, "")
self.assertEqual(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
self.assertFailure(d, NoSuchCommand)
self.proto.dataReceived("ERROR\r\n")
return d
def test_clientError(self):
"""
Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
token, the originating command fails with L{ClientError}, and the error
contains the text sent by the server.
"""
a = "eggspamm"
d = self.proto.set("foo", a)
self.assertEqual(self.transport.value(),
"set foo 0 0 8\r\neggspamm\r\n")
self.assertFailure(d, ClientError)
def check(err):
self.assertEqual(str(err), "We don't like egg and spam")
d.addCallback(check)
self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
return d
def test_serverError(self):
"""
Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
token, the originating command fails with L{ServerError}, and the error
contains the text sent by the server.
"""
a = "eggspamm"
d = self.proto.set("foo", a)
self.assertEqual(self.transport.value(),
"set foo 0 0 8\r\neggspamm\r\n")
self.assertFailure(d, ServerError)
def check(err):
self.assertEqual(str(err), "zomg")
d.addCallback(check)
self.proto.dataReceived("SERVER_ERROR zomg\r\n")
return d
def test_unicodeKey(self):
"""
Using a non-string key as argument to commands raises an error.
"""
d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
d3 = self.assertFailure(self.proto.get(1), ClientError)
d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
d7 = self.assertFailure(
self.proto.getMultiple(["egg", 1]), ClientError)
return gatherResults([d1, d2, d3, d4, d5, d6, d7])
def test_unicodeValue(self):
"""
Using a non-string value raises an error.
"""
return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)
def test_pipelining(self):
"""
Multiple requests can be sent subsequently to the server, and the
protocol orders the responses correctly and dispatch to the
corresponding client command.
"""
d1 = self.proto.get("foo")
d1.addCallback(self.assertEqual, (0, "bar"))
d2 = self.proto.set("bar", "spamspamspam")
d2.addCallback(self.assertEqual, True)
d3 = self.proto.get("egg")
d3.addCallback(self.assertEqual, (0, "spam"))
self.assertEqual(
self.transport.value(),
"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
"STORED\r\n"
"VALUE egg 0 4\r\nspam\r\nEND\r\n")
return gatherResults([d1, d2, d3])
def test_getInChunks(self):
"""
If the value retrieved by a C{get} arrive in chunks, the protocol
is able to reconstruct it and to produce the good value.
"""
d = self.proto.get("foo")
d.addCallback(self.assertEqual, (0, "0123456789"))
self.assertEqual(self.transport.value(), "get foo\r\n")
self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
self.proto.dataReceived("789")
self.proto.dataReceived("\r\nEND")
self.proto.dataReceived("\r\n")
return d
def test_append(self):
"""
L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
method: it returns a L{Deferred} which is called back with C{True} when
the operation succeeds.
"""
return self._test(
self.proto.append("foo", "bar"),
"append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_prepend(self):
"""
L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
method: it returns a L{Deferred} which is called back with C{True} when
the operation succeeds.
"""
return self._test(
self.proto.prepend("foo", "bar"),
"prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
def test_gets(self):
"""
L{MemCacheProtocol.get} handles an additional cas result when
C{withIdentifier} is C{True} and forward it in the resulting
L{Deferred}.
"""
return self._test(
self.proto.get("foo", True), "gets foo\r\n",
"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar"))
def test_emptyGets(self):
"""
Test getting a non-available key with gets: it succeeds but return
C{None} as value, C{0} as flag and an empty cas value.
"""
return self._test(
self.proto.get("foo", True), "gets foo\r\n",
"END\r\n", (0, "", None))
def test_getsMultiple(self):
"""
L{MemCacheProtocol.getMultiple} handles an additional cas field in the
returned tuples if C{withIdentifier} is C{True}.
"""
return self._test(
self.proto.getMultiple(["foo", "bar"], True),
"gets foo bar\r\n",
"VALUE foo 0 3 1234\r\negg\r\n"
"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
{'bar': (0, '2345', 'spam'), 'foo': (0, '1234', 'egg')})
def test_getsMultipleWithEmpty(self):
"""
When getting a non-available key with L{MemCacheProtocol.getMultiple}
when C{withIdentifier} is C{True}, the other keys are retrieved
correctly, and the non-available key gets a tuple of C{0} as flag,
C{None} as value, and an empty cas value.
"""
return self._test(
self.proto.getMultiple(["foo", "bar"], True),
"gets foo bar\r\n",
"VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
{'bar': (0, '', None), 'foo': (0, '1234', 'egg')})
def test_checkAndSet(self):
"""
L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
that the server handles to check if the data has to be updated.
"""
return self._test(
self.proto.checkAndSet("foo", "bar", cas="1234"),
"cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)
def test_casUnknowKey(self):
"""
When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
resulting L{Deferred} fires with C{False}.
"""
return self._test(
self.proto.checkAndSet("foo", "bar", cas="1234"),
"cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
class CommandFailureTests(CommandMixin, TestCase):
"""
Tests for correct failure of commands on a disconnected
L{MemCacheProtocol}.
"""
def setUp(self):
"""
Create a disconnected memcache client, using a deterministic clock.
"""
self.proto = MemCacheProtocol()
self.clock = Clock()
self.proto.callLater = self.clock.callLater
self.transport = StringTransportWithDisconnection()
self.transport.protocol = self.proto
self.proto.makeConnection(self.transport)
self.transport.loseConnection()
def _test(self, d, send, recv, result):
"""
Implementation of C{_test} which checks that the command fails with
C{RuntimeError} because the transport is disconnected. All the
parameters except C{d} are ignored.
"""
return self.assertFailure(d, RuntimeError)

View file

@ -0,0 +1,514 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted.python.modules, abstract access to imported or importable
objects.
"""
import sys
import itertools
import zipfile
import compileall
import twisted
from twisted.trial.unittest import TestCase
from twisted.python import modules
from twisted.python.filepath import FilePath
from twisted.python.reflect import namedAny
from twisted.python.test.modules_helpers import TwistedModulesMixin
from twisted.python.test.test_zippath import zipit
class TwistedModulesTestCase(TwistedModulesMixin, TestCase):
"""
Base class for L{modules} test cases.
"""
def findByIteration(self, modname, where=modules, importPackages=False):
"""
You don't ever actually want to do this, so it's not in the public
API, but sometimes we want to compare the result of an iterative call
with a lookup call and make sure they're the same for test purposes.
"""
for modinfo in where.walkModules(importPackages=importPackages):
if modinfo.name == modname:
return modinfo
self.fail("Unable to find module %r through iteration." % (modname,))
class BasicTests(TwistedModulesTestCase):
def test_namespacedPackages(self):
"""
Duplicate packages are not yielded when iterating over namespace
packages.
"""
# Force pkgutil to be loaded already, since the probe package being
# created depends on it, and the replaceSysPath call below will make
# pretty much everything unimportable.
__import__('pkgutil')
namespaceBoilerplate = (
'import pkgutil; '
'__path__ = pkgutil.extend_path(__path__, __name__)')
# Create two temporary directories with packages:
#
# entry:
# test_package/
# __init__.py
# nested_package/
# __init__.py
# module.py
#
# anotherEntry:
# test_package/
# __init__.py
# nested_package/
# __init__.py
# module2.py
#
# test_package and test_package.nested_package are namespace packages,
# and when both of these are in sys.path, test_package.nested_package
# should become a virtual package containing both "module" and
# "module2"
entry = self.pathEntryWithOnePackage()
testPackagePath = entry.child('test_package')
testPackagePath.child('__init__.py').setContent(namespaceBoilerplate)
nestedEntry = testPackagePath.child('nested_package')
nestedEntry.makedirs()
nestedEntry.child('__init__.py').setContent(namespaceBoilerplate)
nestedEntry.child('module.py').setContent('')
anotherEntry = self.pathEntryWithOnePackage()
anotherPackagePath = anotherEntry.child('test_package')
anotherPackagePath.child('__init__.py').setContent(namespaceBoilerplate)
anotherNestedEntry = anotherPackagePath.child('nested_package')
anotherNestedEntry.makedirs()
anotherNestedEntry.child('__init__.py').setContent(namespaceBoilerplate)
anotherNestedEntry.child('module2.py').setContent('')
self.replaceSysPath([entry.path, anotherEntry.path])
module = modules.getModule('test_package')
# We have to use importPackages=True in order to resolve the namespace
# packages, so we remove the imported packages from sys.modules after
# walking
try:
walkedNames = [
mod.name for mod in module.walkModules(importPackages=True)]
finally:
for module in sys.modules.keys():
if module.startswith('test_package'):
del sys.modules[module]
expected = [
'test_package',
'test_package.nested_package',
'test_package.nested_package.module',
'test_package.nested_package.module2',
]
self.assertEqual(walkedNames, expected)
def test_unimportablePackageGetItem(self):
"""
If a package has been explicitly forbidden from importing by setting a
C{None} key in sys.modules under its name,
L{modules.PythonPath.__getitem__} should still be able to retrieve an
unloaded L{modules.PythonModule} for that package.
"""
shouldNotLoad = []
path = modules.PythonPath(sysPath=[self.pathEntryWithOnePackage().path],
moduleLoader=shouldNotLoad.append,
importerCache={},
sysPathHooks={},
moduleDict={'test_package': None})
self.assertEqual(shouldNotLoad, [])
self.assertEqual(path['test_package'].isLoaded(), False)
def test_unimportablePackageWalkModules(self):
"""
If a package has been explicitly forbidden from importing by setting a
C{None} key in sys.modules under its name, L{modules.walkModules} should
still be able to retrieve an unloaded L{modules.PythonModule} for that
package.
"""
existentPath = self.pathEntryWithOnePackage()
self.replaceSysPath([existentPath.path])
self.replaceSysModules({"test_package": None})
walked = list(modules.walkModules())
self.assertEqual([m.name for m in walked],
["test_package"])
self.assertEqual(walked[0].isLoaded(), False)
def test_nonexistentPaths(self):
"""
Verify that L{modules.walkModules} ignores entries in sys.path which
do not exist in the filesystem.
"""
existentPath = self.pathEntryWithOnePackage()
nonexistentPath = FilePath(self.mktemp())
self.failIf(nonexistentPath.exists())
self.replaceSysPath([existentPath.path])
expected = [modules.getModule("test_package")]
beforeModules = list(modules.walkModules())
sys.path.append(nonexistentPath.path)
afterModules = list(modules.walkModules())
self.assertEqual(beforeModules, expected)
self.assertEqual(afterModules, expected)
def test_nonDirectoryPaths(self):
"""
Verify that L{modules.walkModules} ignores entries in sys.path which
refer to regular files in the filesystem.
"""
existentPath = self.pathEntryWithOnePackage()
nonDirectoryPath = FilePath(self.mktemp())
self.failIf(nonDirectoryPath.exists())
nonDirectoryPath.setContent("zip file or whatever\n")
self.replaceSysPath([existentPath.path])
beforeModules = list(modules.walkModules())
sys.path.append(nonDirectoryPath.path)
afterModules = list(modules.walkModules())
self.assertEqual(beforeModules, afterModules)
def test_twistedShowsUp(self):
"""
Scrounge around in the top-level module namespace and make sure that
Twisted shows up, and that the module thusly obtained is the same as
the module that we find when we look for it explicitly by name.
"""
self.assertEqual(modules.getModule('twisted'),
self.findByIteration("twisted"))
def test_dottedNames(self):
"""
Verify that the walkModules APIs will give us back subpackages, not just
subpackages.
"""
self.assertEqual(
modules.getModule('twisted.python'),
self.findByIteration("twisted.python",
where=modules.getModule('twisted')))
def test_onlyTopModules(self):
"""
Verify that the iterModules API will only return top-level modules and
packages, not submodules or subpackages.
"""
for module in modules.iterModules():
self.failIf(
'.' in module.name,
"no nested modules should be returned from iterModules: %r"
% (module.filePath))
def test_loadPackagesAndModules(self):
"""
Verify that we can locate and load packages, modules, submodules, and
subpackages.
"""
for n in ['os',
'twisted',
'twisted.python',
'twisted.python.reflect']:
m = namedAny(n)
self.failUnlessIdentical(
modules.getModule(n).load(),
m)
self.failUnlessIdentical(
self.findByIteration(n).load(),
m)
def test_pathEntriesOnPath(self):
"""
Verify that path entries discovered via module loading are, in fact, on
sys.path somewhere.
"""
for n in ['os',
'twisted',
'twisted.python',
'twisted.python.reflect']:
self.failUnlessIn(
modules.getModule(n).pathEntry.filePath.path,
sys.path)
def test_alwaysPreferPy(self):
"""
Verify that .py files will always be preferred to .pyc files, regardless of
directory listing order.
"""
mypath = FilePath(self.mktemp())
mypath.createDirectory()
pp = modules.PythonPath(sysPath=[mypath.path])
originalSmartPath = pp._smartPath
def _evilSmartPath(pathName):
o = originalSmartPath(pathName)
originalChildren = o.children
def evilChildren():
# normally this order is random; let's make sure it always
# comes up .pyc-first.
x = originalChildren()
x.sort()
x.reverse()
return x
o.children = evilChildren
return o
mypath.child("abcd.py").setContent('\n')
compileall.compile_dir(mypath.path, quiet=True)
# sanity check
self.assertEqual(len(mypath.children()), 2)
pp._smartPath = _evilSmartPath
self.assertEqual(pp['abcd'].filePath,
mypath.child('abcd.py'))
def test_packageMissingPath(self):
"""
A package can delete its __path__ for some reasons,
C{modules.PythonPath} should be able to deal with it.
"""
mypath = FilePath(self.mktemp())
mypath.createDirectory()
pp = modules.PythonPath(sysPath=[mypath.path])
subpath = mypath.child("abcd")
subpath.createDirectory()
subpath.child("__init__.py").setContent('del __path__\n')
sys.path.append(mypath.path)
__import__("abcd")
try:
l = list(pp.walkModules())
self.assertEqual(len(l), 1)
self.assertEqual(l[0].name, 'abcd')
finally:
del sys.modules['abcd']
sys.path.remove(mypath.path)
class PathModificationTest(TwistedModulesTestCase):
"""
These tests share setup/cleanup behavior of creating a dummy package and
stuffing some code in it.
"""
_serialnum = itertools.count().next # used to generate serial numbers for
# package names.
def setUp(self):
self.pathExtensionName = self.mktemp()
self.pathExtension = FilePath(self.pathExtensionName)
self.pathExtension.createDirectory()
self.packageName = "pyspacetests%d" % (self._serialnum(),)
self.packagePath = self.pathExtension.child(self.packageName)
self.packagePath.createDirectory()
self.packagePath.child("__init__.py").setContent("")
self.packagePath.child("a.py").setContent("")
self.packagePath.child("b.py").setContent("")
self.packagePath.child("c__init__.py").setContent("")
self.pathSetUp = False
def _setupSysPath(self):
assert not self.pathSetUp
self.pathSetUp = True
sys.path.append(self.pathExtensionName)
def _underUnderPathTest(self, doImport=True):
moddir2 = self.mktemp()
fpmd = FilePath(moddir2)
fpmd.createDirectory()
fpmd.child("foozle.py").setContent("x = 123\n")
self.packagePath.child("__init__.py").setContent(
"__path__.append(%r)\n" % (moddir2,))
# Cut here
self._setupSysPath()
modinfo = modules.getModule(self.packageName)
self.assertEqual(
self.findByIteration(self.packageName+".foozle", modinfo,
importPackages=doImport),
modinfo['foozle'])
self.assertEqual(modinfo['foozle'].load().x, 123)
def test_underUnderPathAlreadyImported(self):
"""
Verify that iterModules will honor the __path__ of already-loaded packages.
"""
self._underUnderPathTest()
def test_underUnderPathNotAlreadyImported(self):
"""
Verify that iterModules will honor the __path__ of already-loaded packages.
"""
self._underUnderPathTest(False)
test_underUnderPathNotAlreadyImported.todo = (
"This may be impossible but it sure would be nice.")
def _listModules(self):
pkginfo = modules.getModule(self.packageName)
nfni = [modinfo.name.split(".")[-1] for modinfo in
pkginfo.iterModules()]
nfni.sort()
self.assertEqual(nfni, ['a', 'b', 'c__init__'])
def test_listingModules(self):
"""
Make sure the module list comes back as we expect from iterModules on a
package, whether zipped or not.
"""
self._setupSysPath()
self._listModules()
def test_listingModulesAlreadyImported(self):
"""
Make sure the module list comes back as we expect from iterModules on a
package, whether zipped or not, even if the package has already been
imported.
"""
self._setupSysPath()
namedAny(self.packageName)
self._listModules()
def tearDown(self):
# Intentionally using 'assert' here, this is not a test assertion, this
# is just an "oh fuck what is going ON" assertion. -glyph
if self.pathSetUp:
HORK = "path cleanup failed: don't be surprised if other tests break"
assert sys.path.pop() is self.pathExtensionName, HORK+", 1"
assert self.pathExtensionName not in sys.path, HORK+", 2"
class RebindingTest(PathModificationTest):
"""
These tests verify that the default path interrogation API works properly
even when sys.path has been rebound to a different object.
"""
def _setupSysPath(self):
assert not self.pathSetUp
self.pathSetUp = True
self.savedSysPath = sys.path
sys.path = sys.path[:]
sys.path.append(self.pathExtensionName)
def tearDown(self):
"""
Clean up sys.path by re-binding our original object.
"""
if self.pathSetUp:
sys.path = self.savedSysPath
class ZipPathModificationTest(PathModificationTest):
def _setupSysPath(self):
assert not self.pathSetUp
zipit(self.pathExtensionName, self.pathExtensionName+'.zip')
self.pathExtensionName += '.zip'
assert zipfile.is_zipfile(self.pathExtensionName)
PathModificationTest._setupSysPath(self)
class PythonPathTestCase(TestCase):
"""
Tests for the class which provides the implementation for all of the
public API of L{twisted.python.modules}, L{PythonPath}.
"""
def test_unhandledImporter(self):
"""
Make sure that the behavior when encountering an unknown importer
type is not catastrophic failure.
"""
class SecretImporter(object):
pass
def hook(name):
return SecretImporter()
syspath = ['example/path']
sysmodules = {}
syshooks = [hook]
syscache = {}
def sysloader(name):
return None
space = modules.PythonPath(
syspath, sysmodules, syshooks, syscache, sysloader)
entries = list(space.iterEntries())
self.assertEqual(len(entries), 1)
self.assertRaises(KeyError, lambda: entries[0]['module'])
def test_inconsistentImporterCache(self):
"""
If the path a module loaded with L{PythonPath.__getitem__} is not
present in the path importer cache, a warning is emitted, but the
L{PythonModule} is returned as usual.
"""
space = modules.PythonPath([], sys.modules, [], {})
thisModule = space[__name__]
warnings = self.flushWarnings([self.test_inconsistentImporterCache])
self.assertEqual(warnings[0]['category'], UserWarning)
self.assertEqual(
warnings[0]['message'],
FilePath(twisted.__file__).parent().dirname() +
" (for module " + __name__ + ") not in path importer cache "
"(PEP 302 violation - check your local configuration).")
self.assertEqual(len(warnings), 1)
self.assertEqual(thisModule.name, __name__)
def test_containsModule(self):
"""
L{PythonPath} implements the C{in} operator so that when it is the
right-hand argument and the name of a module which exists on that
L{PythonPath} is the left-hand argument, the result is C{True}.
"""
thePath = modules.PythonPath()
self.assertIn('os', thePath)
def test_doesntContainModule(self):
"""
L{PythonPath} implements the C{in} operator so that when it is the
right-hand argument and the name of a module which does not exist on
that L{PythonPath} is the left-hand argument, the result is C{False}.
"""
thePath = modules.PythonPath()
self.assertNotIn('bogusModule', thePath)

View file

@ -0,0 +1,164 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.monkey}.
"""
from __future__ import division, absolute_import
from twisted.trial import unittest
from twisted.python.monkey import MonkeyPatcher
class TestObj:
def __init__(self):
self.foo = 'foo value'
self.bar = 'bar value'
self.baz = 'baz value'
class MonkeyPatcherTest(unittest.SynchronousTestCase):
"""
Tests for L{MonkeyPatcher} monkey-patching class.
"""
def setUp(self):
self.testObject = TestObj()
self.originalObject = TestObj()
self.monkeyPatcher = MonkeyPatcher()
def test_empty(self):
"""
A monkey patcher without patches shouldn't change a thing.
"""
self.monkeyPatcher.patch()
# We can't assert that all state is unchanged, but at least we can
# check our test object.
self.assertEqual(self.originalObject.foo, self.testObject.foo)
self.assertEqual(self.originalObject.bar, self.testObject.bar)
self.assertEqual(self.originalObject.baz, self.testObject.baz)
def test_constructWithPatches(self):
"""
Constructing a L{MonkeyPatcher} with patches should add all of the
given patches to the patch list.
"""
patcher = MonkeyPatcher((self.testObject, 'foo', 'haha'),
(self.testObject, 'bar', 'hehe'))
patcher.patch()
self.assertEqual('haha', self.testObject.foo)
self.assertEqual('hehe', self.testObject.bar)
self.assertEqual(self.originalObject.baz, self.testObject.baz)
def test_patchExisting(self):
"""
Patching an attribute that exists sets it to the value defined in the
patch.
"""
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
self.monkeyPatcher.patch()
self.assertEqual(self.testObject.foo, 'haha')
def test_patchNonExisting(self):
"""
Patching a non-existing attribute fails with an C{AttributeError}.
"""
self.monkeyPatcher.addPatch(self.testObject, 'nowhere',
'blow up please')
self.assertRaises(AttributeError, self.monkeyPatcher.patch)
def test_patchAlreadyPatched(self):
"""
Adding a patch for an object and attribute that already have a patch
overrides the existing patch.
"""
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'blah')
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'BLAH')
self.monkeyPatcher.patch()
self.assertEqual(self.testObject.foo, 'BLAH')
self.monkeyPatcher.restore()
self.assertEqual(self.testObject.foo, self.originalObject.foo)
def test_restoreTwiceIsANoOp(self):
"""
Restoring an already-restored monkey patch is a no-op.
"""
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'blah')
self.monkeyPatcher.patch()
self.monkeyPatcher.restore()
self.assertEqual(self.testObject.foo, self.originalObject.foo)
self.monkeyPatcher.restore()
self.assertEqual(self.testObject.foo, self.originalObject.foo)
def test_runWithPatchesDecoration(self):
"""
runWithPatches should run the given callable, passing in all arguments
and keyword arguments, and return the return value of the callable.
"""
log = []
def f(a, b, c=None):
log.append((a, b, c))
return 'foo'
result = self.monkeyPatcher.runWithPatches(f, 1, 2, c=10)
self.assertEqual('foo', result)
self.assertEqual([(1, 2, 10)], log)
def test_repeatedRunWithPatches(self):
"""
We should be able to call the same function with runWithPatches more
than once. All patches should apply for each call.
"""
def f():
return (self.testObject.foo, self.testObject.bar,
self.testObject.baz)
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
result = self.monkeyPatcher.runWithPatches(f)
self.assertEqual(
('haha', self.originalObject.bar, self.originalObject.baz), result)
result = self.monkeyPatcher.runWithPatches(f)
self.assertEqual(
('haha', self.originalObject.bar, self.originalObject.baz),
result)
def test_runWithPatchesRestores(self):
"""
C{runWithPatches} should restore the original values after the function
has executed.
"""
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
self.assertEqual(self.originalObject.foo, self.testObject.foo)
self.monkeyPatcher.runWithPatches(lambda: None)
self.assertEqual(self.originalObject.foo, self.testObject.foo)
def test_runWithPatchesRestoresOnException(self):
"""
Test runWithPatches restores the original values even when the function
raises an exception.
"""
def _():
self.assertEqual(self.testObject.foo, 'haha')
self.assertEqual(self.testObject.bar, 'blahblah')
raise RuntimeError("Something went wrong!")
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
self.monkeyPatcher.addPatch(self.testObject, 'bar', 'blahblah')
self.assertRaises(RuntimeError, self.monkeyPatcher.runWithPatches, _)
self.assertEqual(self.testObject.foo, self.originalObject.foo)
self.assertEqual(self.testObject.bar, self.originalObject.bar)

View file

@ -0,0 +1,435 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.cred}, now with 30% more starch.
"""
import hmac
from zope.interface import implements, Interface
from twisted.trial import unittest
from twisted.cred import portal, checkers, credentials, error
from twisted.python import components
from twisted.internet import defer
from twisted.internet.defer import deferredGenerator as dG, waitForDeferred as wFD
try:
from crypt import crypt
except ImportError:
crypt = None
try:
from twisted.cred import pamauth
except ImportError:
pamauth = None
class ITestable(Interface):
pass
class TestAvatar:
def __init__(self, name):
self.name = name
self.loggedIn = False
self.loggedOut = False
def login(self):
assert not self.loggedIn
self.loggedIn = True
def logout(self):
self.loggedOut = True
class Testable(components.Adapter):
implements(ITestable)
# components.Interface(TestAvatar).adaptWith(Testable, ITestable)
components.registerAdapter(Testable, TestAvatar, ITestable)
class IDerivedCredentials(credentials.IUsernamePassword):
pass
class DerivedCredentials(object):
implements(IDerivedCredentials, ITestable)
def __init__(self, username, password):
self.username = username
self.password = password
def checkPassword(self, password):
return password == self.password
class TestRealm:
implements(portal.IRealm)
def __init__(self):
self.avatars = {}
def requestAvatar(self, avatarId, mind, *interfaces):
if avatarId in self.avatars:
avatar = self.avatars[avatarId]
else:
avatar = TestAvatar(avatarId)
self.avatars[avatarId] = avatar
avatar.login()
return (interfaces[0], interfaces[0](avatar),
avatar.logout)
class NewCredTest(unittest.TestCase):
def setUp(self):
r = self.realm = TestRealm()
p = self.portal = portal.Portal(r)
up = self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
up.addUser("bob", "hello")
p.registerChecker(up)
def testListCheckers(self):
expected = [credentials.IUsernamePassword, credentials.IUsernameHashedPassword]
got = self.portal.listCredentialsInterfaces()
expected.sort()
got.sort()
self.assertEqual(got, expected)
def testBasicLogin(self):
l = []; f = []
self.portal.login(credentials.UsernamePassword("bob", "hello"),
self, ITestable).addCallback(
l.append).addErrback(f.append)
if f:
raise f[0]
# print l[0].getBriefTraceback()
iface, impl, logout = l[0]
# whitebox
self.assertEqual(iface, ITestable)
self.failUnless(iface.providedBy(impl),
"%s does not implement %s" % (impl, iface))
# greybox
self.failUnless(impl.original.loggedIn)
self.failUnless(not impl.original.loggedOut)
logout()
self.failUnless(impl.original.loggedOut)
def test_derivedInterface(self):
"""
Login with credentials implementing an interface inheriting from an
interface registered with a checker (but not itself registered).
"""
l = []
f = []
self.portal.login(DerivedCredentials("bob", "hello"), self, ITestable
).addCallback(l.append
).addErrback(f.append)
if f:
raise f[0]
iface, impl, logout = l[0]
# whitebox
self.assertEqual(iface, ITestable)
self.failUnless(iface.providedBy(impl),
"%s does not implement %s" % (impl, iface))
# greybox
self.failUnless(impl.original.loggedIn)
self.failUnless(not impl.original.loggedOut)
logout()
self.failUnless(impl.original.loggedOut)
def testFailedLogin(self):
l = []
self.portal.login(credentials.UsernamePassword("bob", "h3llo"),
self, ITestable).addErrback(
lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
self.failUnless(l)
self.assertEqual(error.UnauthorizedLogin, l[0])
def testFailedLoginName(self):
l = []
self.portal.login(credentials.UsernamePassword("jay", "hello"),
self, ITestable).addErrback(
lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
self.failUnless(l)
self.assertEqual(error.UnauthorizedLogin, l[0])
class CramMD5CredentialsTestCase(unittest.TestCase):
def testIdempotentChallenge(self):
c = credentials.CramMD5Credentials()
chal = c.getChallenge()
self.assertEqual(chal, c.getChallenge())
def testCheckPassword(self):
c = credentials.CramMD5Credentials()
chal = c.getChallenge()
c.response = hmac.HMAC('secret', chal).hexdigest()
self.failUnless(c.checkPassword('secret'))
def testWrongPassword(self):
c = credentials.CramMD5Credentials()
self.failIf(c.checkPassword('secret'))
class OnDiskDatabaseTestCase(unittest.TestCase):
users = [
('user1', 'pass1'),
('user2', 'pass2'),
('user3', 'pass3'),
]
def testUserLookup(self):
dbfile = self.mktemp()
db = checkers.FilePasswordDB(dbfile)
f = file(dbfile, 'w')
for (u, p) in self.users:
f.write('%s:%s\n' % (u, p))
f.close()
for (u, p) in self.users:
self.failUnlessRaises(KeyError, db.getUser, u.upper())
self.assertEqual(db.getUser(u), (u, p))
def testCaseInSensitivity(self):
dbfile = self.mktemp()
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
f = file(dbfile, 'w')
for (u, p) in self.users:
f.write('%s:%s\n' % (u, p))
f.close()
for (u, p) in self.users:
self.assertEqual(db.getUser(u.upper()), (u, p))
def testRequestAvatarId(self):
dbfile = self.mktemp()
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
f = file(dbfile, 'w')
for (u, p) in self.users:
f.write('%s:%s\n' % (u, p))
f.close()
creds = [credentials.UsernamePassword(u, p) for u, p in self.users]
d = defer.gatherResults(
[defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
d.addCallback(self.assertEqual, [u for u, p in self.users])
return d
def testRequestAvatarId_hashed(self):
dbfile = self.mktemp()
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
f = file(dbfile, 'w')
for (u, p) in self.users:
f.write('%s:%s\n' % (u, p))
f.close()
creds = [credentials.UsernameHashedPassword(u, p) for u, p in self.users]
d = defer.gatherResults(
[defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
d.addCallback(self.assertEqual, [u for u, p in self.users])
return d
class HashedPasswordOnDiskDatabaseTestCase(unittest.TestCase):
users = [
('user1', 'pass1'),
('user2', 'pass2'),
('user3', 'pass3'),
]
def hash(self, u, p, s):
return crypt(p, s)
def setUp(self):
dbfile = self.mktemp()
self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
f = file(dbfile, 'w')
for (u, p) in self.users:
f.write('%s:%s\n' % (u, crypt(p, u[:2])))
f.close()
r = TestRealm()
self.port = portal.Portal(r)
self.port.registerChecker(self.db)
def testGoodCredentials(self):
goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds])
d.addCallback(self.assertEqual, [u for u, p in self.users])
return d
def testGoodCredentials_login(self):
goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
d = defer.gatherResults([self.port.login(c, None, ITestable)
for c in goodCreds])
d.addCallback(lambda x: [a.original.name for i, a, l in x])
d.addCallback(self.assertEqual, [u for u, p in self.users])
return d
def testBadCredentials(self):
badCreds = [credentials.UsernamePassword(u, 'wrong password')
for u, p in self.users]
d = defer.DeferredList([self.port.login(c, None, ITestable)
for c in badCreds], consumeErrors=True)
d.addCallback(self._assertFailures, error.UnauthorizedLogin)
return d
def testHashedCredentials(self):
hashedCreds = [credentials.UsernameHashedPassword(u, crypt(p, u[:2]))
for u, p in self.users]
d = defer.DeferredList([self.port.login(c, None, ITestable)
for c in hashedCreds], consumeErrors=True)
d.addCallback(self._assertFailures, error.UnhandledCredentials)
return d
def _assertFailures(self, failures, *expectedFailures):
for flag, failure in failures:
self.assertEqual(flag, defer.FAILURE)
failure.trap(*expectedFailures)
return None
if crypt is None:
skip = "crypt module not available"
class PluggableAuthenticationModulesTest(unittest.TestCase):
def setUp(self):
"""
Replace L{pamauth.callIntoPAM} with a dummy implementation with
easily-controlled behavior.
"""
self.patch(pamauth, 'callIntoPAM', self.callIntoPAM)
def callIntoPAM(self, service, user, conv):
if service != 'Twisted':
raise error.UnauthorizedLogin('bad service: %s' % service)
if user != 'testuser':
raise error.UnauthorizedLogin('bad username: %s' % user)
questions = [
(1, "Password"),
(2, "Message w/ Input"),
(3, "Message w/o Input"),
]
replies = conv(questions)
if replies != [
("password", 0),
("entry", 0),
("", 0)
]:
raise error.UnauthorizedLogin('bad conversion: %s' % repr(replies))
return 1
def _makeConv(self, d):
def conv(questions):
return defer.succeed([(d[t], 0) for t, q in questions])
return conv
def testRequestAvatarId(self):
db = checkers.PluggableAuthenticationModulesChecker()
conv = self._makeConv({1:'password', 2:'entry', 3:''})
creds = credentials.PluggableAuthenticationModules('testuser',
conv)
d = db.requestAvatarId(creds)
d.addCallback(self.assertEqual, 'testuser')
return d
def testBadCredentials(self):
db = checkers.PluggableAuthenticationModulesChecker()
conv = self._makeConv({1:'', 2:'', 3:''})
creds = credentials.PluggableAuthenticationModules('testuser',
conv)
d = db.requestAvatarId(creds)
self.assertFailure(d, error.UnauthorizedLogin)
return d
def testBadUsername(self):
db = checkers.PluggableAuthenticationModulesChecker()
conv = self._makeConv({1:'password', 2:'entry', 3:''})
creds = credentials.PluggableAuthenticationModules('baduser',
conv)
d = db.requestAvatarId(creds)
self.assertFailure(d, error.UnauthorizedLogin)
return d
if not pamauth:
skip = "Can't run without PyPAM"
class CheckersMixin:
def testPositive(self):
for chk in self.getCheckers():
for (cred, avatarId) in self.getGoodCredentials():
r = wFD(chk.requestAvatarId(cred))
yield r
self.assertEqual(r.getResult(), avatarId)
testPositive = dG(testPositive)
def testNegative(self):
for chk in self.getCheckers():
for cred in self.getBadCredentials():
r = wFD(chk.requestAvatarId(cred))
yield r
self.assertRaises(error.UnauthorizedLogin, r.getResult)
testNegative = dG(testNegative)
class HashlessFilePasswordDBMixin:
credClass = credentials.UsernamePassword
diskHash = None
networkHash = staticmethod(lambda x: x)
_validCredentials = [
('user1', 'password1'),
('user2', 'password2'),
('user3', 'password3')]
def getGoodCredentials(self):
for u, p in self._validCredentials:
yield self.credClass(u, self.networkHash(p)), u
def getBadCredentials(self):
for u, p in [('user1', 'password3'),
('user2', 'password1'),
('bloof', 'blarf')]:
yield self.credClass(u, self.networkHash(p))
def getCheckers(self):
diskHash = self.diskHash or (lambda x: x)
hashCheck = self.diskHash and (lambda username, password, stored: self.diskHash(password))
for cache in True, False:
fn = self.mktemp()
fObj = file(fn, 'w')
for u, p in self._validCredentials:
fObj.write('%s:%s\n' % (u, diskHash(p)))
fObj.close()
yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck)
fn = self.mktemp()
fObj = file(fn, 'w')
for u, p in self._validCredentials:
fObj.write('%s dingle dongle %s\n' % (diskHash(p), u))
fObj.close()
yield checkers.FilePasswordDB(fn, ' ', 3, 0, cache=cache, hash=hashCheck)
fn = self.mktemp()
fObj = file(fn, 'w')
for u, p in self._validCredentials:
fObj.write('zip,zap,%s,zup,%s\n' % (u.title(), diskHash(p)))
fObj.close()
yield checkers.FilePasswordDB(fn, ',', 2, 4, False, cache=cache, hash=hashCheck)
class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
diskHash = staticmethod(lambda x: x.encode('hex'))
class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
networkHash = staticmethod(lambda x: x.encode('hex'))
class credClass(credentials.UsernameHashedPassword):
def checkPassword(self, password):
return self.hashed.decode('hex') == password
class HashlessFilePasswordDBCheckerTestCase(HashlessFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
pass
class LocallyHashedFilePasswordDBCheckerTestCase(LocallyHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
pass
class NetworkHashedFilePasswordDBCheckerTestCase(NetworkHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
pass

View file

@ -0,0 +1,115 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""Test cases for the NMEA GPS protocol"""
import StringIO
from twisted.trial import unittest
from twisted.internet import reactor, protocol
from twisted.python import reflect
from twisted.protocols.gps import nmea
class StringIOWithNoClose(StringIO.StringIO):
def close(self):
pass
class ResultHarvester:
def __init__(self):
self.results = []
def __call__(self, *args):
self.results.append(args)
def performTest(self, function, *args, **kwargs):
l = len(self.results)
try:
function(*args, **kwargs)
except Exception, e:
self.results.append(e)
if l == len(self.results):
self.results.append(NotImplementedError())
class NMEATester(nmea.NMEAReceiver):
ignore_invalid_sentence = 0
ignore_checksum_mismatch = 0
ignore_unknown_sentencetypes = 0
convert_dates_before_y2k = 1
def connectionMade(self):
self.resultHarvester = ResultHarvester()
for fn in reflect.prefixedMethodNames(self.__class__, 'decode_'):
setattr(self, 'handle_' + fn, self.resultHarvester)
class NMEAReceiverTestCase(unittest.TestCase):
messages = (
# fix - signal acquired
"$GPGGA,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*69",
# fix - signal not acquired
"$GPGGA,235947.000,0000.0000,N,00000.0000,E,0,00,0.0,0.0,M,,,,0000*00",
# junk
"lkjasdfkl!@#(*$!@(*#(ASDkfjasdfLMASDCVKAW!@#($)!(@#)(*",
# fix - signal acquired (invalid checksum)
"$GPGGA,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*68",
# invalid sentence
"$GPGGX,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*68",
# position acquired
"$GPGLL,4250.5589,S,14718.5084,E,092204.999,A*2D",
# position not acquired
"$GPGLL,0000.0000,N,00000.0000,E,235947.000,V*2D",
# active satellites (no fix)
"$GPGSA,A,1,,,,,,,,,,,,,0.0,0.0,0.0*30",
# active satellites
"$GPGSA,A,3,01,20,19,13,,,,,,,,,40.4,24.4,32.2*0A",
# positiontime (no fix)
"$GPRMC,235947.000,V,0000.0000,N,00000.0000,E,,,041299,,*1D",
# positiontime
"$GPRMC,092204.999,A,4250.5589,S,14718.5084,E,0.00,89.68,211200,,*25",
# course over ground (no fix - not implemented)
"$GPVTG,,T,,M,,N,,K*4E",
# course over ground (not implemented)
"$GPVTG,89.68,T,,M,0.00,N,0.0,K*5F",
)
results = (
(83833.0, 39.17355, -76.6999, nmea.POSFIX_SPS, 5, 1.35, (44.0, 'M'), (-33.0, 'M'), None),
(86387.0, 0.0, 0.0, 0, 0, 0.0, (0.0, 'M'), None, None),
nmea.InvalidSentence(),
nmea.InvalidChecksum(),
nmea.InvalidSentence(),
(-42.842648333333337, 147.30847333333332, 33724.999000000003, 1),
(0.0, 0.0, 86387.0, 0),
((None, None, None, None, None, None, None, None, None, None, None, None), (nmea.MODE_AUTO, nmea.MODE_NOFIX), 0.0, 0.0, 0.0),
((1, 20, 19, 13, None, None, None, None, None, None, None, None), (nmea.MODE_AUTO, nmea.MODE_3D), 40.4, 24.4, 32.2),
(0.0, 0.0, None, None, 86387.0, (1999, 12, 4), None),
(-42.842648333333337, 147.30847333333332, 0.0, 89.68, 33724.999, (2000, 12, 21), None),
NotImplementedError(),
NotImplementedError(),
)
def testGPSMessages(self):
dummy = NMEATester()
dummy.makeConnection(protocol.FileWrapper(StringIOWithNoClose()))
for line in self.messages:
dummy.resultHarvester.performTest(dummy.lineReceived, line)
def munge(myTuple):
if type(myTuple) != type(()):
return
newTuple = []
for v in myTuple:
if type(v) == type(1.1):
v = float(int(v * 10000.0)) * 0.0001
newTuple.append(v)
return tuple(newTuple)
for (message, expectedResult, actualResult) in zip(self.messages, self.results, dummy.resultHarvester.results):
expectedResult = munge(expectedResult)
actualResult = munge(actualResult)
if isinstance(expectedResult, Exception):
if isinstance(actualResult, Exception):
self.assertEqual(expectedResult.__class__, actualResult.__class__, "\nInput:\n%s\nExpected:\n%s.%s\nResults:\n%s.%s\n" % (message, expectedResult.__class__.__module__, expectedResult.__class__.__name__, actualResult.__class__.__module__, actualResult.__class__.__name__))
else:
self.assertEqual(1, 0, "\nInput:\n%s\nExpected:\n%s.%s\nResults:\n%r\n" % (message, expectedResult.__class__.__module__, expectedResult.__class__.__name__, actualResult))
else:
self.assertEqual(expectedResult, actualResult, "\nInput:\n%s\nExpected: %r\nResults: %r\n" % (message, expectedResult, actualResult))
testCases = [NMEAReceiverTestCase]

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,469 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for error handling in PB.
"""
from StringIO import StringIO
from twisted.trial import unittest
from twisted.spread import pb, flavors, jelly
from twisted.internet import reactor, defer
from twisted.python import log
##
# test exceptions
##
class AsynchronousException(Exception):
"""
Helper used to test remote methods which return Deferreds which fail with
exceptions which are not L{pb.Error} subclasses.
"""
class SynchronousException(Exception):
"""
Helper used to test remote methods which raise exceptions which are not
L{pb.Error} subclasses.
"""
class AsynchronousError(pb.Error):
"""
Helper used to test remote methods which return Deferreds which fail with
exceptions which are L{pb.Error} subclasses.
"""
class SynchronousError(pb.Error):
"""
Helper used to test remote methods which raise exceptions which are
L{pb.Error} subclasses.
"""
#class JellyError(flavors.Jellyable, pb.Error): pass
class JellyError(flavors.Jellyable, pb.Error, pb.RemoteCopy):
pass
class SecurityError(pb.Error, pb.RemoteCopy):
pass
pb.setUnjellyableForClass(JellyError, JellyError)
pb.setUnjellyableForClass(SecurityError, SecurityError)
pb.globalSecurity.allowInstancesOf(SecurityError)
####
# server-side
####
class SimpleRoot(pb.Root):
def remote_asynchronousException(self):
"""
Fail asynchronously with a non-pb.Error exception.
"""
return defer.fail(AsynchronousException("remote asynchronous exception"))
def remote_synchronousException(self):
"""
Fail synchronously with a non-pb.Error exception.
"""
raise SynchronousException("remote synchronous exception")
def remote_asynchronousError(self):
"""
Fail asynchronously with a pb.Error exception.
"""
return defer.fail(AsynchronousError("remote asynchronous error"))
def remote_synchronousError(self):
"""
Fail synchronously with a pb.Error exception.
"""
raise SynchronousError("remote synchronous error")
def remote_unknownError(self):
"""
Fail with error that is not known to client.
"""
class UnknownError(pb.Error):
pass
raise UnknownError("I'm not known to client!")
def remote_jelly(self):
self.raiseJelly()
def remote_security(self):
self.raiseSecurity()
def remote_deferredJelly(self):
d = defer.Deferred()
d.addCallback(self.raiseJelly)
d.callback(None)
return d
def remote_deferredSecurity(self):
d = defer.Deferred()
d.addCallback(self.raiseSecurity)
d.callback(None)
return d
def raiseJelly(self, results=None):
raise JellyError("I'm jellyable!")
def raiseSecurity(self, results=None):
raise SecurityError("I'm secure!")
class SaveProtocolServerFactory(pb.PBServerFactory):
"""
A L{pb.PBServerFactory} that saves the latest connected client in
C{protocolInstance}.
"""
protocolInstance = None
def clientConnectionMade(self, protocol):
"""
Keep track of the given protocol.
"""
self.protocolInstance = protocol
class PBConnTestCase(unittest.TestCase):
unsafeTracebacks = 0
def setUp(self):
self._setUpServer()
self._setUpClient()
def _setUpServer(self):
self.serverFactory = SaveProtocolServerFactory(SimpleRoot())
self.serverFactory.unsafeTracebacks = self.unsafeTracebacks
self.serverPort = reactor.listenTCP(0, self.serverFactory, interface="127.0.0.1")
def _setUpClient(self):
portNo = self.serverPort.getHost().port
self.clientFactory = pb.PBClientFactory()
self.clientConnector = reactor.connectTCP("127.0.0.1", portNo, self.clientFactory)
def tearDown(self):
if self.serverFactory.protocolInstance is not None:
self.serverFactory.protocolInstance.transport.loseConnection()
return defer.gatherResults([
self._tearDownServer(),
self._tearDownClient()])
def _tearDownServer(self):
return defer.maybeDeferred(self.serverPort.stopListening)
def _tearDownClient(self):
self.clientConnector.disconnect()
return defer.succeed(None)
class PBFailureTest(PBConnTestCase):
compare = unittest.TestCase.assertEqual
def _exceptionTest(self, method, exceptionType, flush):
def eb(err):
err.trap(exceptionType)
self.compare(err.traceback, "Traceback unavailable\n")
if flush:
errs = self.flushLoggedErrors(exceptionType)
self.assertEqual(len(errs), 1)
return (err.type, err.value, err.traceback)
d = self.clientFactory.getRootObject()
def gotRootObject(root):
d = root.callRemote(method)
d.addErrback(eb)
return d
d.addCallback(gotRootObject)
return d
def test_asynchronousException(self):
"""
Test that a Deferred returned by a remote method which already has a
Failure correctly has that error passed back to the calling side.
"""
return self._exceptionTest(
'asynchronousException', AsynchronousException, True)
def test_synchronousException(self):
"""
Like L{test_asynchronousException}, but for a method which raises an
exception synchronously.
"""
return self._exceptionTest(
'synchronousException', SynchronousException, True)
def test_asynchronousError(self):
"""
Like L{test_asynchronousException}, but for a method which returns a
Deferred failing with an L{pb.Error} subclass.
"""
return self._exceptionTest(
'asynchronousError', AsynchronousError, False)
def test_synchronousError(self):
"""
Like L{test_asynchronousError}, but for a method which synchronously
raises a L{pb.Error} subclass.
"""
return self._exceptionTest(
'synchronousError', SynchronousError, False)
def _success(self, result, expectedResult):
self.assertEqual(result, expectedResult)
return result
def _addFailingCallbacks(self, remoteCall, expectedResult, eb):
remoteCall.addCallbacks(self._success, eb,
callbackArgs=(expectedResult,))
return remoteCall
def _testImpl(self, method, expected, eb, exc=None):
"""
Call the given remote method and attach the given errback to the
resulting Deferred. If C{exc} is not None, also assert that one
exception of that type was logged.
"""
rootDeferred = self.clientFactory.getRootObject()
def gotRootObj(obj):
failureDeferred = self._addFailingCallbacks(obj.callRemote(method), expected, eb)
if exc is not None:
def gotFailure(err):
self.assertEqual(len(self.flushLoggedErrors(exc)), 1)
return err
failureDeferred.addBoth(gotFailure)
return failureDeferred
rootDeferred.addCallback(gotRootObj)
return rootDeferred
def test_jellyFailure(self):
"""
Test that an exception which is a subclass of L{pb.Error} has more
information passed across the network to the calling side.
"""
def failureJelly(fail):
fail.trap(JellyError)
self.failIf(isinstance(fail.type, str))
self.failUnless(isinstance(fail.value, fail.type))
return 43
return self._testImpl('jelly', 43, failureJelly)
def test_deferredJellyFailure(self):
"""
Test that a Deferred which fails with a L{pb.Error} is treated in
the same way as a synchronously raised L{pb.Error}.
"""
def failureDeferredJelly(fail):
fail.trap(JellyError)
self.failIf(isinstance(fail.type, str))
self.failUnless(isinstance(fail.value, fail.type))
return 430
return self._testImpl('deferredJelly', 430, failureDeferredJelly)
def test_unjellyableFailure(self):
"""
An non-jellyable L{pb.Error} subclass raised by a remote method is
turned into a Failure with a type set to the FQPN of the exception
type.
"""
def failureUnjellyable(fail):
self.assertEqual(
fail.type, 'twisted.test.test_pbfailure.SynchronousError')
return 431
return self._testImpl('synchronousError', 431, failureUnjellyable)
def test_unknownFailure(self):
"""
Test that an exception which is a subclass of L{pb.Error} but not
known on the client side has its type set properly.
"""
def failureUnknown(fail):
self.assertEqual(
fail.type, 'twisted.test.test_pbfailure.UnknownError')
return 4310
return self._testImpl('unknownError', 4310, failureUnknown)
def test_securityFailure(self):
"""
Test that even if an exception is not explicitly jellyable (by being
a L{pb.Jellyable} subclass), as long as it is an L{pb.Error}
subclass it receives the same special treatment.
"""
def failureSecurity(fail):
fail.trap(SecurityError)
self.failIf(isinstance(fail.type, str))
self.failUnless(isinstance(fail.value, fail.type))
return 4300
return self._testImpl('security', 4300, failureSecurity)
def test_deferredSecurity(self):
"""
Test that a Deferred which fails with a L{pb.Error} which is not
also a L{pb.Jellyable} is treated in the same way as a synchronously
raised exception of the same type.
"""
def failureDeferredSecurity(fail):
fail.trap(SecurityError)
self.failIf(isinstance(fail.type, str))
self.failUnless(isinstance(fail.value, fail.type))
return 43000
return self._testImpl('deferredSecurity', 43000, failureDeferredSecurity)
def test_noSuchMethodFailure(self):
"""
Test that attempting to call a method which is not defined correctly
results in an AttributeError on the calling side.
"""
def failureNoSuch(fail):
fail.trap(pb.NoSuchMethod)
self.compare(fail.traceback, "Traceback unavailable\n")
return 42000
return self._testImpl('nosuch', 42000, failureNoSuch, AttributeError)
def test_copiedFailureLogging(self):
"""
Test that a copied failure received from a PB call can be logged
locally.
Note: this test needs some serious help: all it really tests is that
log.err(copiedFailure) doesn't raise an exception.
"""
d = self.clientFactory.getRootObject()
def connected(rootObj):
return rootObj.callRemote('synchronousException')
d.addCallback(connected)
def exception(failure):
log.err(failure)
errs = self.flushLoggedErrors(SynchronousException)
self.assertEqual(len(errs), 2)
d.addErrback(exception)
return d
def test_throwExceptionIntoGenerator(self):
"""
L{pb.CopiedFailure.throwExceptionIntoGenerator} will throw a
L{RemoteError} into the given paused generator at the point where it
last yielded.
"""
original = pb.CopyableFailure(AttributeError("foo"))
copy = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
exception = []
def generatorFunc():
try:
yield None
except pb.RemoteError, exc:
exception.append(exc)
else:
self.fail("RemoteError not raised")
gen = generatorFunc()
gen.send(None)
self.assertRaises(StopIteration, copy.throwExceptionIntoGenerator, gen)
self.assertEqual(len(exception), 1)
exc = exception[0]
self.assertEqual(exc.remoteType, "exceptions.AttributeError")
self.assertEqual(exc.args, ("foo",))
self.assertEqual(exc.remoteTraceback, 'Traceback unavailable\n')
class PBFailureTestUnsafe(PBFailureTest):
compare = unittest.TestCase.failIfEquals
unsafeTracebacks = 1
class DummyInvoker(object):
"""
A behaviorless object to be used as the invoker parameter to
L{jelly.jelly}.
"""
serializingPerspective = None
class FailureJellyingTests(unittest.TestCase):
"""
Tests for the interaction of jelly and failures.
"""
def test_unjelliedFailureCheck(self):
"""
An unjellied L{CopyableFailure} has a check method which behaves the
same way as the original L{CopyableFailure}'s check method.
"""
original = pb.CopyableFailure(ZeroDivisionError())
self.assertIdentical(
original.check(ZeroDivisionError), ZeroDivisionError)
self.assertIdentical(original.check(ArithmeticError), ArithmeticError)
copied = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
self.assertIdentical(
copied.check(ZeroDivisionError), ZeroDivisionError)
self.assertIdentical(copied.check(ArithmeticError), ArithmeticError)
def test_twiceUnjelliedFailureCheck(self):
"""
The object which results from jellying a L{CopyableFailure}, unjellying
the result, creating a new L{CopyableFailure} from the result of that,
jellying it, and finally unjellying the result of that has a check
method which behaves the same way as the original L{CopyableFailure}'s
check method.
"""
original = pb.CopyableFailure(ZeroDivisionError())
self.assertIdentical(
original.check(ZeroDivisionError), ZeroDivisionError)
self.assertIdentical(original.check(ArithmeticError), ArithmeticError)
copiedOnce = jelly.unjelly(
jelly.jelly(original, invoker=DummyInvoker()))
derivative = pb.CopyableFailure(copiedOnce)
copiedTwice = jelly.unjelly(
jelly.jelly(derivative, invoker=DummyInvoker()))
self.assertIdentical(
copiedTwice.check(ZeroDivisionError), ZeroDivisionError)
self.assertIdentical(
copiedTwice.check(ArithmeticError), ArithmeticError)
def test_printTracebackIncludesValue(self):
"""
When L{CopiedFailure.printTraceback} is used to print a copied failure
which was unjellied from a L{CopyableFailure} with C{unsafeTracebacks}
set to C{False}, the string representation of the exception value is
included in the output.
"""
original = pb.CopyableFailure(Exception("some reason"))
copied = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
output = StringIO()
copied.printTraceback(output)
self.assertEqual(
"Traceback from remote host -- Traceback unavailable\n"
"exceptions.Exception: some reason\n",
output.getvalue())

View file

@ -0,0 +1,368 @@
# -*- Python -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
__version__ = '$Revision: 1.5 $'[11:-2]
from StringIO import StringIO
from twisted.trial import unittest
from twisted.protocols import pcp
# Goal:
# Take a Protocol instance. Own all outgoing data - anything that
# would go to p.transport.write. Own all incoming data - anything
# that comes to p.dataReceived.
# I need:
# Something with the AbstractFileDescriptor interface.
# That is:
# - acts as a Transport
# - has a method write()
# - which buffers
# - acts as a Consumer
# - has a registerProducer, unRegisterProducer
# - tells the Producer to back off (pauseProducing) when its buffer is full.
# - tells the Producer to resumeProducing when its buffer is not so full.
# - acts as a Producer
# - calls registerProducer
# - calls write() on consumers
# - honors requests to pause/resume producing
# - honors stopProducing, and passes it along to upstream Producers
class DummyTransport:
"""A dumb transport to wrap around."""
def __init__(self):
self._writes = []
def write(self, data):
self._writes.append(data)
def getvalue(self):
return ''.join(self._writes)
class DummyProducer:
resumed = False
stopped = False
paused = False
def __init__(self, consumer):
self.consumer = consumer
def resumeProducing(self):
self.resumed = True
self.paused = False
def pauseProducing(self):
self.paused = True
def stopProducing(self):
self.stopped = True
class DummyConsumer(DummyTransport):
producer = None
finished = False
unregistered = True
def registerProducer(self, producer, streaming):
self.producer = (producer, streaming)
def unregisterProducer(self):
self.unregistered = True
def finish(self):
self.finished = True
class TransportInterfaceTest(unittest.TestCase):
proxyClass = pcp.BasicProducerConsumerProxy
def setUp(self):
self.underlying = DummyConsumer()
self.transport = self.proxyClass(self.underlying)
def testWrite(self):
self.transport.write("some bytes")
class ConsumerInterfaceTest:
"""Test ProducerConsumerProxy as a Consumer.
Normally we have ProducingServer -> ConsumingTransport.
If I am to go between (Server -> Shaper -> Transport), I have to
play the role of Consumer convincingly for the ProducingServer.
"""
def setUp(self):
self.underlying = DummyConsumer()
self.consumer = self.proxyClass(self.underlying)
self.producer = DummyProducer(self.consumer)
def testRegisterPush(self):
self.consumer.registerProducer(self.producer, True)
## Consumer should NOT have called PushProducer.resumeProducing
self.failIf(self.producer.resumed)
## I'm I'm just a proxy, should I only do resumeProducing when
## I get poked myself?
#def testRegisterPull(self):
# self.consumer.registerProducer(self.producer, False)
# ## Consumer SHOULD have called PushProducer.resumeProducing
# self.failUnless(self.producer.resumed)
def testUnregister(self):
self.consumer.registerProducer(self.producer, False)
self.consumer.unregisterProducer()
# Now when the consumer would ordinarily want more data, it
# shouldn't ask producer for it.
# The most succinct way to trigger "want more data" is to proxy for
# a PullProducer and have someone ask me for data.
self.producer.resumed = False
self.consumer.resumeProducing()
self.failIf(self.producer.resumed)
def testFinish(self):
self.consumer.registerProducer(self.producer, False)
self.consumer.finish()
# I guess finish should behave like unregister?
self.producer.resumed = False
self.consumer.resumeProducing()
self.failIf(self.producer.resumed)
class ProducerInterfaceTest:
"""Test ProducerConsumerProxy as a Producer.
Normally we have ProducingServer -> ConsumingTransport.
If I am to go between (Server -> Shaper -> Transport), I have to
play the role of Producer convincingly for the ConsumingTransport.
"""
def setUp(self):
self.consumer = DummyConsumer()
self.producer = self.proxyClass(self.consumer)
def testRegistersProducer(self):
self.assertEqual(self.consumer.producer[0], self.producer)
def testPause(self):
self.producer.pauseProducing()
self.producer.write("yakkity yak")
self.failIf(self.consumer.getvalue(),
"Paused producer should not have sent data.")
def testResume(self):
self.producer.pauseProducing()
self.producer.resumeProducing()
self.producer.write("yakkity yak")
self.assertEqual(self.consumer.getvalue(), "yakkity yak")
def testResumeNoEmptyWrite(self):
self.producer.pauseProducing()
self.producer.resumeProducing()
self.assertEqual(len(self.consumer._writes), 0,
"Resume triggered an empty write.")
def testResumeBuffer(self):
self.producer.pauseProducing()
self.producer.write("buffer this")
self.producer.resumeProducing()
self.assertEqual(self.consumer.getvalue(), "buffer this")
def testStop(self):
self.producer.stopProducing()
self.producer.write("yakkity yak")
self.failIf(self.consumer.getvalue(),
"Stopped producer should not have sent data.")
class PCP_ConsumerInterfaceTest(ConsumerInterfaceTest, unittest.TestCase):
proxyClass = pcp.BasicProducerConsumerProxy
class PCPII_ConsumerInterfaceTest(ConsumerInterfaceTest, unittest.TestCase):
proxyClass = pcp.ProducerConsumerProxy
class PCP_ProducerInterfaceTest(ProducerInterfaceTest, unittest.TestCase):
proxyClass = pcp.BasicProducerConsumerProxy
class PCPII_ProducerInterfaceTest(ProducerInterfaceTest, unittest.TestCase):
proxyClass = pcp.ProducerConsumerProxy
class ProducerProxyTest(unittest.TestCase):
"""Producer methods on me should be relayed to the Producer I proxy.
"""
proxyClass = pcp.BasicProducerConsumerProxy
def setUp(self):
self.proxy = self.proxyClass(None)
self.parentProducer = DummyProducer(self.proxy)
self.proxy.registerProducer(self.parentProducer, True)
def testStop(self):
self.proxy.stopProducing()
self.failUnless(self.parentProducer.stopped)
class ConsumerProxyTest(unittest.TestCase):
"""Consumer methods on me should be relayed to the Consumer I proxy.
"""
proxyClass = pcp.BasicProducerConsumerProxy
def setUp(self):
self.underlying = DummyConsumer()
self.consumer = self.proxyClass(self.underlying)
def testWrite(self):
# NOTE: This test only valid for streaming (Push) systems.
self.consumer.write("some bytes")
self.assertEqual(self.underlying.getvalue(), "some bytes")
def testFinish(self):
self.consumer.finish()
self.failUnless(self.underlying.finished)
def testUnregister(self):
self.consumer.unregisterProducer()
self.failUnless(self.underlying.unregistered)
class PullProducerTest:
def setUp(self):
self.underlying = DummyConsumer()
self.proxy = self.proxyClass(self.underlying)
self.parentProducer = DummyProducer(self.proxy)
self.proxy.registerProducer(self.parentProducer, True)
def testHoldWrites(self):
self.proxy.write("hello")
# Consumer should get no data before it says resumeProducing.
self.failIf(self.underlying.getvalue(),
"Pulling Consumer got data before it pulled.")
def testPull(self):
self.proxy.write("hello")
self.proxy.resumeProducing()
self.assertEqual(self.underlying.getvalue(), "hello")
def testMergeWrites(self):
self.proxy.write("hello ")
self.proxy.write("sunshine")
self.proxy.resumeProducing()
nwrites = len(self.underlying._writes)
self.assertEqual(nwrites, 1, "Pull resulted in %d writes instead "
"of 1." % (nwrites,))
self.assertEqual(self.underlying.getvalue(), "hello sunshine")
def testLateWrite(self):
# consumer sends its initial pull before we have data
self.proxy.resumeProducing()
self.proxy.write("data")
# This data should answer that pull request.
self.assertEqual(self.underlying.getvalue(), "data")
class PCP_PullProducerTest(PullProducerTest, unittest.TestCase):
class proxyClass(pcp.BasicProducerConsumerProxy):
iAmStreaming = False
class PCPII_PullProducerTest(PullProducerTest, unittest.TestCase):
class proxyClass(pcp.ProducerConsumerProxy):
iAmStreaming = False
# Buffering!
class BufferedConsumerTest(unittest.TestCase):
"""As a consumer, ask the producer to pause after too much data."""
proxyClass = pcp.ProducerConsumerProxy
def setUp(self):
self.underlying = DummyConsumer()
self.proxy = self.proxyClass(self.underlying)
self.proxy.bufferSize = 100
self.parentProducer = DummyProducer(self.proxy)
self.proxy.registerProducer(self.parentProducer, True)
def testRegisterPull(self):
self.proxy.registerProducer(self.parentProducer, False)
## Consumer SHOULD have called PushProducer.resumeProducing
self.failUnless(self.parentProducer.resumed)
def testPauseIntercept(self):
self.proxy.pauseProducing()
self.failIf(self.parentProducer.paused)
def testResumeIntercept(self):
self.proxy.pauseProducing()
self.proxy.resumeProducing()
# With a streaming producer, just because the proxy was resumed is
# not necessarily a reason to resume the parent producer. The state
# of the buffer should decide that.
self.failIf(self.parentProducer.resumed)
def testTriggerPause(self):
"""Make sure I say \"when.\""""
# Pause the proxy so data sent to it builds up in its buffer.
self.proxy.pauseProducing()
self.failIf(self.parentProducer.paused, "don't pause yet")
self.proxy.write("x" * 51)
self.failIf(self.parentProducer.paused, "don't pause yet")
self.proxy.write("x" * 51)
self.failUnless(self.parentProducer.paused)
def testTriggerResume(self):
"""Make sure I resumeProducing when my buffer empties."""
self.proxy.pauseProducing()
self.proxy.write("x" * 102)
self.failUnless(self.parentProducer.paused, "should be paused")
self.proxy.resumeProducing()
# Resuming should have emptied my buffer, so I should tell my
# parent to resume too.
self.failIf(self.parentProducer.paused,
"Producer should have resumed.")
self.failIf(self.proxy.producerPaused)
class BufferedPullTests(unittest.TestCase):
class proxyClass(pcp.ProducerConsumerProxy):
iAmStreaming = False
def _writeSomeData(self, data):
pcp.ProducerConsumerProxy._writeSomeData(self, data[:100])
return min(len(data), 100)
def setUp(self):
self.underlying = DummyConsumer()
self.proxy = self.proxyClass(self.underlying)
self.proxy.bufferSize = 100
self.parentProducer = DummyProducer(self.proxy)
self.proxy.registerProducer(self.parentProducer, False)
def testResumePull(self):
# If proxy has no data to send on resumeProducing, it had better pull
# some from its PullProducer.
self.parentProducer.resumed = False
self.proxy.resumeProducing()
self.failUnless(self.parentProducer.resumed)
def testLateWriteBuffering(self):
# consumer sends its initial pull before we have data
self.proxy.resumeProducing()
self.proxy.write("datum" * 21)
# This data should answer that pull request.
self.assertEqual(self.underlying.getvalue(), "datum" * 20)
# but there should be some left over
self.assertEqual(self.proxy._buffer, ["datum"])
# TODO:
# test that web request finishing bug (when we weren't proxying
# unregisterProducer but were proxying finish, web file transfers
# would hang on the last block.)
# test what happens if writeSomeBytes decided to write zero bytes.

View file

@ -0,0 +1,377 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# System Imports
import sys
from twisted.trial import unittest
try:
import cPickle as pickle
except ImportError:
import pickle
try:
import cStringIO as StringIO
except ImportError:
import StringIO
# Twisted Imports
from twisted.persisted import styles, aot, crefutil
class VersionTestCase(unittest.TestCase):
def testNullVersionUpgrade(self):
global NullVersioned
class NullVersioned:
ok = 0
pkcl = pickle.dumps(NullVersioned())
class NullVersioned(styles.Versioned):
persistenceVersion = 1
def upgradeToVersion1(self):
self.ok = 1
mnv = pickle.loads(pkcl)
styles.doUpgrade()
assert mnv.ok, "initial upgrade not run!"
def testVersionUpgrade(self):
global MyVersioned
class MyVersioned(styles.Versioned):
persistenceVersion = 2
persistenceForgets = ['garbagedata']
v3 = 0
v4 = 0
def __init__(self):
self.somedata = 'xxx'
self.garbagedata = lambda q: 'cant persist'
def upgradeToVersion3(self):
self.v3 += 1
def upgradeToVersion4(self):
self.v4 += 1
mv = MyVersioned()
assert not (mv.v3 or mv.v4), "hasn't been upgraded yet"
pickl = pickle.dumps(mv)
MyVersioned.persistenceVersion = 4
obj = pickle.loads(pickl)
styles.doUpgrade()
assert obj.v3, "didn't do version 3 upgrade"
assert obj.v4, "didn't do version 4 upgrade"
pickl = pickle.dumps(obj)
obj = pickle.loads(pickl)
styles.doUpgrade()
assert obj.v3 == 1, "upgraded unnecessarily"
assert obj.v4 == 1, "upgraded unnecessarily"
def testNonIdentityHash(self):
global ClassWithCustomHash
class ClassWithCustomHash(styles.Versioned):
def __init__(self, unique, hash):
self.unique = unique
self.hash = hash
def __hash__(self):
return self.hash
v1 = ClassWithCustomHash('v1', 0)
v2 = ClassWithCustomHash('v2', 0)
pkl = pickle.dumps((v1, v2))
del v1, v2
ClassWithCustomHash.persistenceVersion = 1
ClassWithCustomHash.upgradeToVersion1 = lambda self: setattr(self, 'upgraded', True)
v1, v2 = pickle.loads(pkl)
styles.doUpgrade()
self.assertEqual(v1.unique, 'v1')
self.assertEqual(v2.unique, 'v2')
self.failUnless(v1.upgraded)
self.failUnless(v2.upgraded)
def testUpgradeDeserializesObjectsRequiringUpgrade(self):
global ToyClassA, ToyClassB
class ToyClassA(styles.Versioned):
pass
class ToyClassB(styles.Versioned):
pass
x = ToyClassA()
y = ToyClassB()
pklA, pklB = pickle.dumps(x), pickle.dumps(y)
del x, y
ToyClassA.persistenceVersion = 1
def upgradeToVersion1(self):
self.y = pickle.loads(pklB)
styles.doUpgrade()
ToyClassA.upgradeToVersion1 = upgradeToVersion1
ToyClassB.persistenceVersion = 1
ToyClassB.upgradeToVersion1 = lambda self: setattr(self, 'upgraded', True)
x = pickle.loads(pklA)
styles.doUpgrade()
self.failUnless(x.y.upgraded)
class VersionedSubClass(styles.Versioned):
pass
class SecondVersionedSubClass(styles.Versioned):
pass
class VersionedSubSubClass(VersionedSubClass):
pass
class VersionedDiamondSubClass(VersionedSubSubClass, SecondVersionedSubClass):
pass
class AybabtuTests(unittest.TestCase):
"""
L{styles._aybabtu} gets all of classes in the inheritance hierarchy of its
argument that are strictly between L{Versioned} and the class itself.
"""
def test_aybabtuStrictEmpty(self):
"""
L{styles._aybabtu} of L{Versioned} itself is an empty list.
"""
self.assertEqual(styles._aybabtu(styles.Versioned), [])
def test_aybabtuStrictSubclass(self):
"""
There are no classes I{between} L{VersionedSubClass} and L{Versioned},
so L{styles._aybabtu} returns an empty list.
"""
self.assertEqual(styles._aybabtu(VersionedSubClass), [])
def test_aybabtuSubsubclass(self):
"""
With a sub-sub-class of L{Versioned}, L{styles._aybabtu} returns a list
containing the intervening subclass.
"""
self.assertEqual(styles._aybabtu(VersionedSubSubClass),
[VersionedSubClass])
def test_aybabtuStrict(self):
"""
For a diamond-shaped inheritance graph, L{styles._aybabtu} returns a
list containing I{both} intermediate subclasses.
"""
self.assertEqual(
styles._aybabtu(VersionedDiamondSubClass),
[VersionedSubSubClass, VersionedSubClass, SecondVersionedSubClass])
class MyEphemeral(styles.Ephemeral):
def __init__(self, x):
self.x = x
class EphemeralTestCase(unittest.TestCase):
def testEphemeral(self):
o = MyEphemeral(3)
self.assertEqual(o.__class__, MyEphemeral)
self.assertEqual(o.x, 3)
pickl = pickle.dumps(o)
o = pickle.loads(pickl)
self.assertEqual(o.__class__, styles.Ephemeral)
self.assert_(not hasattr(o, 'x'))
class Pickleable:
def __init__(self, x):
self.x = x
def getX(self):
return self.x
class A:
"""
dummy class
"""
def amethod(self):
pass
class B:
"""
dummy class
"""
def bmethod(self):
pass
def funktion():
pass
class PicklingTestCase(unittest.TestCase):
"""Test pickling of extra object types."""
def testModule(self):
pickl = pickle.dumps(styles)
o = pickle.loads(pickl)
self.assertEqual(o, styles)
def testClassMethod(self):
pickl = pickle.dumps(Pickleable.getX)
o = pickle.loads(pickl)
self.assertEqual(o, Pickleable.getX)
def testInstanceMethod(self):
obj = Pickleable(4)
pickl = pickle.dumps(obj.getX)
o = pickle.loads(pickl)
self.assertEqual(o(), 4)
self.assertEqual(type(o), type(obj.getX))
def testStringIO(self):
f = StringIO.StringIO()
f.write("abc")
pickl = pickle.dumps(f)
o = pickle.loads(pickl)
self.assertEqual(type(o), type(f))
self.assertEqual(f.getvalue(), "abc")
class EvilSourceror:
def __init__(self, x):
self.a = self
self.a.b = self
self.a.b.c = x
class NonDictState:
def __getstate__(self):
return self.state
def __setstate__(self, state):
self.state = state
class AOTTestCase(unittest.TestCase):
def testSimpleTypes(self):
obj = (1, 2.0, 3j, True, slice(1, 2, 3), 'hello', u'world', sys.maxint + 1, None, Ellipsis)
rtObj = aot.unjellyFromSource(aot.jellyToSource(obj))
self.assertEqual(obj, rtObj)
def testMethodSelfIdentity(self):
a = A()
b = B()
a.bmethod = b.bmethod
b.a = a
im_ = aot.unjellyFromSource(aot.jellyToSource(b)).a.bmethod
self.assertEqual(im_.im_class, im_.im_self.__class__)
def test_methodNotSelfIdentity(self):
"""
If a class change after an instance has been created,
L{aot.unjellyFromSource} shoud raise a C{TypeError} when trying to
unjelly the instance.
"""
a = A()
b = B()
a.bmethod = b.bmethod
b.a = a
savedbmethod = B.bmethod
del B.bmethod
try:
self.assertRaises(TypeError, aot.unjellyFromSource,
aot.jellyToSource(b))
finally:
B.bmethod = savedbmethod
def test_unsupportedType(self):
"""
L{aot.jellyToSource} should raise a C{TypeError} when trying to jelly
an unknown type.
"""
try:
set
except:
from sets import Set as set
self.assertRaises(TypeError, aot.jellyToSource, set())
def testBasicIdentity(self):
# Anyone wanting to make this datastructure more complex, and thus this
# test more comprehensive, is welcome to do so.
aj = aot.AOTJellier().jellyToAO
d = {'hello': 'world', "method": aj}
l = [1, 2, 3,
"he\tllo\n\n\"x world!",
u"goodbye \n\t\u1010 world!",
1, 1.0, 100 ** 100l, unittest, aot.AOTJellier, d,
funktion
]
t = tuple(l)
l.append(l)
l.append(t)
l.append(t)
uj = aot.unjellyFromSource(aot.jellyToSource([l, l]))
assert uj[0] is uj[1]
assert uj[1][0:5] == l[0:5]
def testNonDictState(self):
a = NonDictState()
a.state = "meringue!"
assert aot.unjellyFromSource(aot.jellyToSource(a)).state == a.state
def testCopyReg(self):
s = "foo_bar"
sio = StringIO.StringIO()
sio.write(s)
uj = aot.unjellyFromSource(aot.jellyToSource(sio))
# print repr(uj.__dict__)
assert uj.getvalue() == s
def testFunkyReferences(self):
o = EvilSourceror(EvilSourceror([]))
j1 = aot.jellyToAOT(o)
oj = aot.unjellyFromAOT(j1)
assert oj.a is oj
assert oj.a.b is oj.b
assert oj.c is not oj.c.c
class CrefUtilTestCase(unittest.TestCase):
"""
Tests for L{crefutil}.
"""
def test_dictUnknownKey(self):
"""
L{crefutil._DictKeyAndValue} only support keys C{0} and C{1}.
"""
d = crefutil._DictKeyAndValue({})
self.assertRaises(RuntimeError, d.__setitem__, 2, 3)
def test_deferSetMultipleTimes(self):
"""
L{crefutil._Defer} can be assigned a key only one time.
"""
d = crefutil._Defer()
d[0] = 1
self.assertRaises(RuntimeError, d.__setitem__, 0, 1)
testCases = [VersionTestCase, EphemeralTestCase, PicklingTestCase]

View file

@ -0,0 +1,719 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for Twisted plugin system.
"""
import sys, errno, os, time
import compileall
from zope.interface import Interface
from twisted.trial import unittest
from twisted.python.log import textFromEventDict, addObserver, removeObserver
from twisted.python.filepath import FilePath
from twisted.python.util import mergeFunctionMetadata
from twisted import plugin
class ITestPlugin(Interface):
"""
A plugin for use by the plugin system's unit tests.
Do not use this.
"""
class ITestPlugin2(Interface):
"""
See L{ITestPlugin}.
"""
class PluginTestCase(unittest.TestCase):
"""
Tests which verify the behavior of the current, active Twisted plugins
directory.
"""
def setUp(self):
"""
Save C{sys.path} and C{sys.modules}, and create a package for tests.
"""
self.originalPath = sys.path[:]
self.savedModules = sys.modules.copy()
self.root = FilePath(self.mktemp())
self.root.createDirectory()
self.package = self.root.child('mypackage')
self.package.createDirectory()
self.package.child('__init__.py').setContent("")
FilePath(__file__).sibling('plugin_basic.py'
).copyTo(self.package.child('testplugin.py'))
self.originalPlugin = "testplugin"
sys.path.insert(0, self.root.path)
import mypackage
self.module = mypackage
def tearDown(self):
"""
Restore C{sys.path} and C{sys.modules} to their original values.
"""
sys.path[:] = self.originalPath
sys.modules.clear()
sys.modules.update(self.savedModules)
def _unimportPythonModule(self, module, deleteSource=False):
modulePath = module.__name__.split('.')
packageName = '.'.join(modulePath[:-1])
moduleName = modulePath[-1]
delattr(sys.modules[packageName], moduleName)
del sys.modules[module.__name__]
for ext in ['c', 'o'] + (deleteSource and [''] or []):
try:
os.remove(module.__file__ + ext)
except OSError, ose:
if ose.errno != errno.ENOENT:
raise
def _clearCache(self):
"""
Remove the plugins B{droping.cache} file.
"""
self.package.child('dropin.cache').remove()
def _withCacheness(meth):
"""
This is a paranoid test wrapper, that calls C{meth} 2 times, clear the
cache, and calls it 2 other times. It's supposed to ensure that the
plugin system behaves correctly no matter what the state of the cache
is.
"""
def wrapped(self):
meth(self)
meth(self)
self._clearCache()
meth(self)
meth(self)
return mergeFunctionMetadata(meth, wrapped)
def test_cache(self):
"""
Check that the cache returned by L{plugin.getCache} hold the plugin
B{testplugin}, and that this plugin has the properties we expect:
provide L{TestPlugin}, has the good name and description, and can be
loaded successfully.
"""
cache = plugin.getCache(self.module)
dropin = cache[self.originalPlugin]
self.assertEqual(dropin.moduleName,
'mypackage.%s' % (self.originalPlugin,))
self.assertIn("I'm a test drop-in.", dropin.description)
# Note, not the preferred way to get a plugin by its interface.
p1 = [p for p in dropin.plugins if ITestPlugin in p.provided][0]
self.assertIdentical(p1.dropin, dropin)
self.assertEqual(p1.name, "TestPlugin")
# Check the content of the description comes from the plugin module
# docstring
self.assertEqual(
p1.description.strip(),
"A plugin used solely for testing purposes.")
self.assertEqual(p1.provided, [ITestPlugin, plugin.IPlugin])
realPlugin = p1.load()
# The plugin should match the class present in sys.modules
self.assertIdentical(
realPlugin,
sys.modules['mypackage.%s' % (self.originalPlugin,)].TestPlugin)
# And it should also match if we import it classicly
import mypackage.testplugin as tp
self.assertIdentical(realPlugin, tp.TestPlugin)
test_cache = _withCacheness(test_cache)
def test_plugins(self):
"""
L{plugin.getPlugins} should return the list of plugins matching the
specified interface (here, L{ITestPlugin2}), and these plugins
should be instances of classes with a C{test} method, to be sure
L{plugin.getPlugins} load classes correctly.
"""
plugins = list(plugin.getPlugins(ITestPlugin2, self.module))
self.assertEqual(len(plugins), 2)
names = ['AnotherTestPlugin', 'ThirdTestPlugin']
for p in plugins:
names.remove(p.__name__)
p.test()
test_plugins = _withCacheness(test_plugins)
def test_detectNewFiles(self):
"""
Check that L{plugin.getPlugins} is able to detect plugins added at
runtime.
"""
FilePath(__file__).sibling('plugin_extra1.py'
).copyTo(self.package.child('pluginextra.py'))
try:
# Check that the current situation is clean
self.failIfIn('mypackage.pluginextra', sys.modules)
self.failIf(hasattr(sys.modules['mypackage'], 'pluginextra'),
"mypackage still has pluginextra module")
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
# We should find 2 plugins: the one in testplugin, and the one in
# pluginextra
self.assertEqual(len(plgs), 2)
names = ['TestPlugin', 'FourthTestPlugin']
for p in plgs:
names.remove(p.__name__)
p.test1()
finally:
self._unimportPythonModule(
sys.modules['mypackage.pluginextra'],
True)
test_detectNewFiles = _withCacheness(test_detectNewFiles)
def test_detectFilesChanged(self):
"""
Check that if the content of a plugin change, L{plugin.getPlugins} is
able to detect the new plugins added.
"""
FilePath(__file__).sibling('plugin_extra1.py'
).copyTo(self.package.child('pluginextra.py'))
try:
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
# Sanity check
self.assertEqual(len(plgs), 2)
FilePath(__file__).sibling('plugin_extra2.py'
).copyTo(self.package.child('pluginextra.py'))
# Fake out Python.
self._unimportPythonModule(sys.modules['mypackage.pluginextra'])
# Make sure additions are noticed
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
self.assertEqual(len(plgs), 3)
names = ['TestPlugin', 'FourthTestPlugin', 'FifthTestPlugin']
for p in plgs:
names.remove(p.__name__)
p.test1()
finally:
self._unimportPythonModule(
sys.modules['mypackage.pluginextra'],
True)
test_detectFilesChanged = _withCacheness(test_detectFilesChanged)
def test_detectFilesRemoved(self):
"""
Check that when a dropin file is removed, L{plugin.getPlugins} doesn't
return it anymore.
"""
FilePath(__file__).sibling('plugin_extra1.py'
).copyTo(self.package.child('pluginextra.py'))
try:
# Generate a cache with pluginextra in it.
list(plugin.getPlugins(ITestPlugin, self.module))
finally:
self._unimportPythonModule(
sys.modules['mypackage.pluginextra'],
True)
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
self.assertEqual(1, len(plgs))
test_detectFilesRemoved = _withCacheness(test_detectFilesRemoved)
def test_nonexistentPathEntry(self):
"""
Test that getCache skips over any entries in a plugin package's
C{__path__} which do not exist.
"""
path = self.mktemp()
self.failIf(os.path.exists(path))
# Add the test directory to the plugins path
self.module.__path__.append(path)
try:
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
self.assertEqual(len(plgs), 1)
finally:
self.module.__path__.remove(path)
test_nonexistentPathEntry = _withCacheness(test_nonexistentPathEntry)
def test_nonDirectoryChildEntry(self):
"""
Test that getCache skips over any entries in a plugin package's
C{__path__} which refer to children of paths which are not directories.
"""
path = FilePath(self.mktemp())
self.failIf(path.exists())
path.touch()
child = path.child("test_package").path
self.module.__path__.append(child)
try:
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
self.assertEqual(len(plgs), 1)
finally:
self.module.__path__.remove(child)
test_nonDirectoryChildEntry = _withCacheness(test_nonDirectoryChildEntry)
def test_deployedMode(self):
"""
The C{dropin.cache} file may not be writable: the cache should still be
attainable, but an error should be logged to show that the cache
couldn't be updated.
"""
# Generate the cache
plugin.getCache(self.module)
cachepath = self.package.child('dropin.cache')
# Add a new plugin
FilePath(__file__).sibling('plugin_extra1.py'
).copyTo(self.package.child('pluginextra.py'))
os.chmod(self.package.path, 0500)
# Change the right of dropin.cache too for windows
os.chmod(cachepath.path, 0400)
self.addCleanup(os.chmod, self.package.path, 0700)
self.addCleanup(os.chmod, cachepath.path, 0700)
# Start observing log events to see the warning
events = []
addObserver(events.append)
self.addCleanup(removeObserver, events.append)
cache = plugin.getCache(self.module)
# The new plugin should be reported
self.assertIn('pluginextra', cache)
self.assertIn(self.originalPlugin, cache)
# Make sure something was logged about the cache.
expected = "Unable to write to plugin cache %s: error number %d" % (
cachepath.path, errno.EPERM)
for event in events:
if expected in textFromEventDict(event):
break
else:
self.fail(
"Did not observe unwriteable cache warning in log "
"events: %r" % (events,))
# This is something like the Twisted plugins file.
pluginInitFile = """
from twisted.plugin import pluginPackagePaths
__path__.extend(pluginPackagePaths(__name__))
__all__ = []
"""
def pluginFileContents(name):
return (
"from zope.interface import classProvides\n"
"from twisted.plugin import IPlugin\n"
"from twisted.test.test_plugin import ITestPlugin\n"
"\n"
"class %s(object):\n"
" classProvides(IPlugin, ITestPlugin)\n") % (name,)
def _createPluginDummy(entrypath, pluginContent, real, pluginModule):
"""
Create a plugindummy package.
"""
entrypath.createDirectory()
pkg = entrypath.child('plugindummy')
pkg.createDirectory()
if real:
pkg.child('__init__.py').setContent('')
plugs = pkg.child('plugins')
plugs.createDirectory()
if real:
plugs.child('__init__.py').setContent(pluginInitFile)
plugs.child(pluginModule + '.py').setContent(pluginContent)
return plugs
class DeveloperSetupTests(unittest.TestCase):
"""
These tests verify things about the plugin system without actually
interacting with the deployed 'twisted.plugins' package, instead creating a
temporary package.
"""
def setUp(self):
"""
Create a complex environment with multiple entries on sys.path, akin to
a developer's environment who has a development (trunk) checkout of
Twisted, a system installed version of Twisted (for their operating
system's tools) and a project which provides Twisted plugins.
"""
self.savedPath = sys.path[:]
self.savedModules = sys.modules.copy()
self.fakeRoot = FilePath(self.mktemp())
self.fakeRoot.createDirectory()
self.systemPath = self.fakeRoot.child('system_path')
self.devPath = self.fakeRoot.child('development_path')
self.appPath = self.fakeRoot.child('application_path')
self.systemPackage = _createPluginDummy(
self.systemPath, pluginFileContents('system'),
True, 'plugindummy_builtin')
self.devPackage = _createPluginDummy(
self.devPath, pluginFileContents('dev'),
True, 'plugindummy_builtin')
self.appPackage = _createPluginDummy(
self.appPath, pluginFileContents('app'),
False, 'plugindummy_app')
# Now we're going to do the system installation.
sys.path.extend([x.path for x in [self.systemPath,
self.appPath]])
# Run all the way through the plugins list to cause the
# L{plugin.getPlugins} generator to write cache files for the system
# installation.
self.getAllPlugins()
self.sysplug = self.systemPath.child('plugindummy').child('plugins')
self.syscache = self.sysplug.child('dropin.cache')
# Make sure there's a nice big difference in modification times so that
# we won't re-build the system cache.
now = time.time()
os.utime(
self.sysplug.child('plugindummy_builtin.py').path,
(now - 5000,) * 2)
os.utime(self.syscache.path, (now - 2000,) * 2)
# For extra realism, let's make sure that the system path is no longer
# writable.
self.lockSystem()
self.resetEnvironment()
def lockSystem(self):
"""
Lock the system directories, as if they were unwritable by this user.
"""
os.chmod(self.sysplug.path, 0555)
os.chmod(self.syscache.path, 0555)
def unlockSystem(self):
"""
Unlock the system directories, as if they were writable by this user.
"""
os.chmod(self.sysplug.path, 0777)
os.chmod(self.syscache.path, 0777)
def getAllPlugins(self):
"""
Get all the plugins loadable from our dummy package, and return their
short names.
"""
# Import the module we just added to our path. (Local scope because
# this package doesn't exist outside of this test.)
import plugindummy.plugins
x = list(plugin.getPlugins(ITestPlugin, plugindummy.plugins))
return [plug.__name__ for plug in x]
def resetEnvironment(self):
"""
Change the environment to what it should be just as the test is
starting.
"""
self.unsetEnvironment()
sys.path.extend([x.path for x in [self.devPath,
self.systemPath,
self.appPath]])
def unsetEnvironment(self):
"""
Change the Python environment back to what it was before the test was
started.
"""
sys.modules.clear()
sys.modules.update(self.savedModules)
sys.path[:] = self.savedPath
def tearDown(self):
"""
Reset the Python environment to what it was before this test ran, and
restore permissions on files which were marked read-only so that the
directory may be cleanly cleaned up.
"""
self.unsetEnvironment()
# Normally we wouldn't "clean up" the filesystem like this (leaving
# things for post-test inspection), but if we left the permissions the
# way they were, we'd be leaving files around that the buildbots
# couldn't delete, and that would be bad.
self.unlockSystem()
def test_developmentPluginAvailability(self):
"""
Plugins added in the development path should be loadable, even when
the (now non-importable) system path contains its own idea of the
list of plugins for a package. Inversely, plugins added in the
system path should not be available.
"""
# Run 3 times: uncached, cached, and then cached again to make sure we
# didn't overwrite / corrupt the cache on the cached try.
for x in range(3):
names = self.getAllPlugins()
names.sort()
self.assertEqual(names, ['app', 'dev'])
def test_freshPyReplacesStalePyc(self):
"""
Verify that if a stale .pyc file on the PYTHONPATH is replaced by a
fresh .py file, the plugins in the new .py are picked up rather than
the stale .pyc, even if the .pyc is still around.
"""
mypath = self.appPackage.child("stale.py")
mypath.setContent(pluginFileContents('one'))
# Make it super stale
x = time.time() - 1000
os.utime(mypath.path, (x, x))
pyc = mypath.sibling('stale.pyc')
# compile it
compileall.compile_dir(self.appPackage.path, quiet=1)
os.utime(pyc.path, (x, x))
# Eliminate the other option.
mypath.remove()
# Make sure it's the .pyc path getting cached.
self.resetEnvironment()
# Sanity check.
self.assertIn('one', self.getAllPlugins())
self.failIfIn('two', self.getAllPlugins())
self.resetEnvironment()
mypath.setContent(pluginFileContents('two'))
self.failIfIn('one', self.getAllPlugins())
self.assertIn('two', self.getAllPlugins())
def test_newPluginsOnReadOnlyPath(self):
"""
Verify that a failure to write the dropin.cache file on a read-only
path will not affect the list of plugins returned.
Note: this test should pass on both Linux and Windows, but may not
provide useful coverage on Windows due to the different meaning of
"read-only directory".
"""
self.unlockSystem()
self.sysplug.child('newstuff.py').setContent(pluginFileContents('one'))
self.lockSystem()
# Take the developer path out, so that the system plugins are actually
# examined.
sys.path.remove(self.devPath.path)
# Start observing log events to see the warning
events = []
addObserver(events.append)
self.addCleanup(removeObserver, events.append)
self.assertIn('one', self.getAllPlugins())
# Make sure something was logged about the cache.
expected = "Unable to write to plugin cache %s: error number %d" % (
self.syscache.path, errno.EPERM)
for event in events:
if expected in textFromEventDict(event):
break
else:
self.fail(
"Did not observe unwriteable cache warning in log "
"events: %r" % (events,))
class AdjacentPackageTests(unittest.TestCase):
"""
Tests for the behavior of the plugin system when there are multiple
installed copies of the package containing the plugins being loaded.
"""
def setUp(self):
"""
Save the elements of C{sys.path} and the items of C{sys.modules}.
"""
self.originalPath = sys.path[:]
self.savedModules = sys.modules.copy()
def tearDown(self):
"""
Restore C{sys.path} and C{sys.modules} to their original values.
"""
sys.path[:] = self.originalPath
sys.modules.clear()
sys.modules.update(self.savedModules)
def createDummyPackage(self, root, name, pluginName):
"""
Create a directory containing a Python package named I{dummy} with a
I{plugins} subpackage.
@type root: L{FilePath}
@param root: The directory in which to create the hierarchy.
@type name: C{str}
@param name: The name of the directory to create which will contain
the package.
@type pluginName: C{str}
@param pluginName: The name of a module to create in the
I{dummy.plugins} package.
@rtype: L{FilePath}
@return: The directory which was created to contain the I{dummy}
package.
"""
directory = root.child(name)
package = directory.child('dummy')
package.makedirs()
package.child('__init__.py').setContent('')
plugins = package.child('plugins')
plugins.makedirs()
plugins.child('__init__.py').setContent(pluginInitFile)
pluginModule = plugins.child(pluginName + '.py')
pluginModule.setContent(pluginFileContents(name))
return directory
def test_hiddenPackageSamePluginModuleNameObscured(self):
"""
Only plugins from the first package in sys.path should be returned by
getPlugins in the case where there are two Python packages by the same
name installed, each with a plugin module by a single name.
"""
root = FilePath(self.mktemp())
root.makedirs()
firstDirectory = self.createDummyPackage(root, 'first', 'someplugin')
secondDirectory = self.createDummyPackage(root, 'second', 'someplugin')
sys.path.append(firstDirectory.path)
sys.path.append(secondDirectory.path)
import dummy.plugins
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
self.assertEqual(['first'], [p.__name__ for p in plugins])
def test_hiddenPackageDifferentPluginModuleNameObscured(self):
"""
Plugins from the first package in sys.path should be returned by
getPlugins in the case where there are two Python packages by the same
name installed, each with a plugin module by a different name.
"""
root = FilePath(self.mktemp())
root.makedirs()
firstDirectory = self.createDummyPackage(root, 'first', 'thisplugin')
secondDirectory = self.createDummyPackage(root, 'second', 'thatplugin')
sys.path.append(firstDirectory.path)
sys.path.append(secondDirectory.path)
import dummy.plugins
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
self.assertEqual(['first'], [p.__name__ for p in plugins])
class PackagePathTests(unittest.TestCase):
"""
Tests for L{plugin.pluginPackagePaths} which constructs search paths for
plugin packages.
"""
def setUp(self):
"""
Save the elements of C{sys.path}.
"""
self.originalPath = sys.path[:]
def tearDown(self):
"""
Restore C{sys.path} to its original value.
"""
sys.path[:] = self.originalPath
def test_pluginDirectories(self):
"""
L{plugin.pluginPackagePaths} should return a list containing each
directory in C{sys.path} with a suffix based on the supplied package
name.
"""
foo = FilePath('foo')
bar = FilePath('bar')
sys.path = [foo.path, bar.path]
self.assertEqual(
plugin.pluginPackagePaths('dummy.plugins'),
[foo.child('dummy').child('plugins').path,
bar.child('dummy').child('plugins').path])
def test_pluginPackagesExcluded(self):
"""
L{plugin.pluginPackagePaths} should exclude directories which are
Python packages. The only allowed plugin package (the only one
associated with a I{dummy} package which Python will allow to be
imported) will already be known to the caller of
L{plugin.pluginPackagePaths} and will most commonly already be in
the C{__path__} they are about to mutate.
"""
root = FilePath(self.mktemp())
foo = root.child('foo').child('dummy').child('plugins')
foo.makedirs()
foo.child('__init__.py').setContent('')
sys.path = [root.child('foo').path, root.child('bar').path]
self.assertEqual(
plugin.pluginPackagePaths('dummy.plugins'),
[root.child('bar').child('dummy').child('plugins').path])

View file

@ -0,0 +1,872 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test code for policies.
"""
from __future__ import division, absolute_import
from zope.interface import Interface, implementer, implementedBy
from twisted.python.compat import NativeStringIO, _PY3
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.internet import protocol, reactor, address, defer, task
from twisted.protocols import policies
class SimpleProtocol(protocol.Protocol):
connected = disconnected = 0
buffer = b""
def __init__(self):
self.dConnected = defer.Deferred()
self.dDisconnected = defer.Deferred()
def connectionMade(self):
self.connected = 1
self.dConnected.callback('')
def connectionLost(self, reason):
self.disconnected = 1
self.dDisconnected.callback('')
def dataReceived(self, data):
self.buffer += data
class SillyFactory(protocol.ClientFactory):
def __init__(self, p):
self.p = p
def buildProtocol(self, addr):
return self.p
class EchoProtocol(protocol.Protocol):
paused = False
def pauseProducing(self):
self.paused = True
def resumeProducing(self):
self.paused = False
def stopProducing(self):
pass
def dataReceived(self, data):
self.transport.write(data)
class Server(protocol.ServerFactory):
"""
A simple server factory using L{EchoProtocol}.
"""
protocol = EchoProtocol
class TestableThrottlingFactory(policies.ThrottlingFactory):
"""
L{policies.ThrottlingFactory} using a L{task.Clock} for tests.
"""
def __init__(self, clock, *args, **kwargs):
"""
@param clock: object providing a callLater method that can be used
for tests.
@type clock: C{task.Clock} or alike.
"""
policies.ThrottlingFactory.__init__(self, *args, **kwargs)
self.clock = clock
def callLater(self, period, func):
"""
Forward to the testable clock.
"""
return self.clock.callLater(period, func)
class TestableTimeoutFactory(policies.TimeoutFactory):
"""
L{policies.TimeoutFactory} using a L{task.Clock} for tests.
"""
def __init__(self, clock, *args, **kwargs):
"""
@param clock: object providing a callLater method that can be used
for tests.
@type clock: C{task.Clock} or alike.
"""
policies.TimeoutFactory.__init__(self, *args, **kwargs)
self.clock = clock
def callLater(self, period, func):
"""
Forward to the testable clock.
"""
return self.clock.callLater(period, func)
class WrapperTestCase(unittest.TestCase):
"""
Tests for L{WrappingFactory} and L{ProtocolWrapper}.
"""
def test_protocolFactoryAttribute(self):
"""
Make sure protocol.factory is the wrapped factory, not the wrapping
factory.
"""
f = Server()
wf = policies.WrappingFactory(f)
p = wf.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 35))
self.assertIdentical(p.wrappedProtocol.factory, f)
def test_transportInterfaces(self):
"""
The transport wrapper passed to the wrapped protocol's
C{makeConnection} provides the same interfaces as are provided by the
original transport.
"""
class IStubTransport(Interface):
pass
@implementer(IStubTransport)
class StubTransport:
pass
# Looking up what ProtocolWrapper implements also mutates the class.
# It adds __implemented__ and __providedBy__ attributes to it. These
# prevent __getattr__ from causing the IStubTransport.providedBy call
# below from returning True. If, by accident, nothing else causes
# these attributes to be added to ProtocolWrapper, the test will pass,
# but the interface will only be provided until something does trigger
# their addition. So we just trigger it right now to be sure.
implementedBy(policies.ProtocolWrapper)
proto = protocol.Protocol()
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(None), proto)
wrapper.makeConnection(StubTransport())
self.assertTrue(IStubTransport.providedBy(proto.transport))
def test_factoryLogPrefix(self):
"""
L{WrappingFactory.logPrefix} is customized to mention both the original
factory and the wrapping factory.
"""
server = Server()
factory = policies.WrappingFactory(server)
self.assertEqual("Server (WrappingFactory)", factory.logPrefix())
def test_factoryLogPrefixFallback(self):
"""
If the wrapped factory doesn't have a L{logPrefix} method,
L{WrappingFactory.logPrefix} falls back to the factory class name.
"""
class NoFactory(object):
pass
server = NoFactory()
factory = policies.WrappingFactory(server)
self.assertEqual("NoFactory (WrappingFactory)", factory.logPrefix())
def test_protocolLogPrefix(self):
"""
L{ProtocolWrapper.logPrefix} is customized to mention both the original
protocol and the wrapper.
"""
server = Server()
factory = policies.WrappingFactory(server)
protocol = factory.buildProtocol(
address.IPv4Address('TCP', '127.0.0.1', 35))
self.assertEqual("EchoProtocol (ProtocolWrapper)",
protocol.logPrefix())
def test_protocolLogPrefixFallback(self):
"""
If the wrapped protocol doesn't have a L{logPrefix} method,
L{ProtocolWrapper.logPrefix} falls back to the protocol class name.
"""
class NoProtocol(object):
pass
server = Server()
server.protocol = NoProtocol
factory = policies.WrappingFactory(server)
protocol = factory.buildProtocol(
address.IPv4Address('TCP', '127.0.0.1', 35))
self.assertEqual("NoProtocol (ProtocolWrapper)",
protocol.logPrefix())
def _getWrapper(self):
"""
Return L{policies.ProtocolWrapper} that has been connected to a
L{StringTransport}.
"""
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(Server()),
protocol.Protocol())
transport = StringTransport()
wrapper.makeConnection(transport)
return wrapper
def test_getHost(self):
"""
L{policies.ProtocolWrapper.getHost} calls C{getHost} on the underlying
transport.
"""
wrapper = self._getWrapper()
self.assertEqual(wrapper.getHost(), wrapper.transport.getHost())
def test_getPeer(self):
"""
L{policies.ProtocolWrapper.getPeer} calls C{getPeer} on the underlying
transport.
"""
wrapper = self._getWrapper()
self.assertEqual(wrapper.getPeer(), wrapper.transport.getPeer())
def test_registerProducer(self):
"""
L{policies.ProtocolWrapper.registerProducer} calls C{registerProducer}
on the underlying transport.
"""
wrapper = self._getWrapper()
producer = object()
wrapper.registerProducer(producer, True)
self.assertIdentical(wrapper.transport.producer, producer)
self.assertTrue(wrapper.transport.streaming)
def test_unregisterProducer(self):
"""
L{policies.ProtocolWrapper.unregisterProducer} calls
C{unregisterProducer} on the underlying transport.
"""
wrapper = self._getWrapper()
producer = object()
wrapper.registerProducer(producer, True)
wrapper.unregisterProducer()
self.assertIdentical(wrapper.transport.producer, None)
self.assertIdentical(wrapper.transport.streaming, None)
def test_stopConsuming(self):
"""
L{policies.ProtocolWrapper.stopConsuming} calls C{stopConsuming} on
the underlying transport.
"""
wrapper = self._getWrapper()
result = []
wrapper.transport.stopConsuming = lambda: result.append(True)
wrapper.stopConsuming()
self.assertEqual(result, [True])
def test_startedConnecting(self):
"""
L{policies.WrappingFactory.startedConnecting} calls
C{startedConnecting} on the underlying factory.
"""
result = []
class Factory(object):
def startedConnecting(self, connector):
result.append(connector)
wrapper = policies.WrappingFactory(Factory())
connector = object()
wrapper.startedConnecting(connector)
self.assertEqual(result, [connector])
def test_clientConnectionLost(self):
"""
L{policies.WrappingFactory.clientConnectionLost} calls
C{clientConnectionLost} on the underlying factory.
"""
result = []
class Factory(object):
def clientConnectionLost(self, connector, reason):
result.append((connector, reason))
wrapper = policies.WrappingFactory(Factory())
connector = object()
reason = object()
wrapper.clientConnectionLost(connector, reason)
self.assertEqual(result, [(connector, reason)])
def test_clientConnectionFailed(self):
"""
L{policies.WrappingFactory.clientConnectionFailed} calls
C{clientConnectionFailed} on the underlying factory.
"""
result = []
class Factory(object):
def clientConnectionFailed(self, connector, reason):
result.append((connector, reason))
wrapper = policies.WrappingFactory(Factory())
connector = object()
reason = object()
wrapper.clientConnectionFailed(connector, reason)
self.assertEqual(result, [(connector, reason)])
class WrappingFactory(policies.WrappingFactory):
protocol = lambda s, f, p: p
def startFactory(self):
policies.WrappingFactory.startFactory(self)
self.deferred.callback(None)
class ThrottlingTestCase(unittest.TestCase):
"""
Tests for L{policies.ThrottlingFactory}.
"""
def test_limit(self):
"""
Full test using a custom server limiting number of connections.
"""
server = Server()
c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
tServer = policies.ThrottlingFactory(server, 2)
wrapTServer = WrappingFactory(tServer)
wrapTServer.deferred = defer.Deferred()
# Start listening
p = reactor.listenTCP(0, wrapTServer, interface="127.0.0.1")
n = p.getHost().port
def _connect123(results):
reactor.connectTCP("127.0.0.1", n, SillyFactory(c1))
c1.dConnected.addCallback(
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c2)))
c2.dConnected.addCallback(
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c3)))
return c3.dDisconnected
def _check123(results):
self.assertEqual([c.connected for c in (c1, c2, c3)], [1, 1, 1])
self.assertEqual([c.disconnected for c in (c1, c2, c3)], [0, 0, 1])
self.assertEqual(len(tServer.protocols.keys()), 2)
return results
def _lose1(results):
# disconnect one protocol and now another should be able to connect
c1.transport.loseConnection()
return c1.dDisconnected
def _connect4(results):
reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
return c4.dConnected
def _check4(results):
self.assertEqual(c4.connected, 1)
self.assertEqual(c4.disconnected, 0)
return results
def _cleanup(results):
for c in c2, c4:
c.transport.loseConnection()
return defer.DeferredList([
defer.maybeDeferred(p.stopListening),
c2.dDisconnected,
c4.dDisconnected])
wrapTServer.deferred.addCallback(_connect123)
wrapTServer.deferred.addCallback(_check123)
wrapTServer.deferred.addCallback(_lose1)
wrapTServer.deferred.addCallback(_connect4)
wrapTServer.deferred.addCallback(_check4)
wrapTServer.deferred.addCallback(_cleanup)
return wrapTServer.deferred
def test_writeSequence(self):
"""
L{ThrottlingProtocol.writeSequence} is called on the underlying factory.
"""
server = Server()
tServer = TestableThrottlingFactory(task.Clock(), server)
protocol = tServer.buildProtocol(
address.IPv4Address('TCP', '127.0.0.1', 0))
transport = StringTransportWithDisconnection()
transport.protocol = protocol
protocol.makeConnection(transport)
protocol.writeSequence([b'bytes'] * 4)
self.assertEqual(transport.value(), b"bytesbytesbytesbytes")
self.assertEqual(tServer.writtenThisSecond, 20)
def test_writeLimit(self):
"""
Check the writeLimit parameter: write data, and check for the pause
status.
"""
server = Server()
tServer = TestableThrottlingFactory(task.Clock(), server, writeLimit=10)
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
tr = StringTransportWithDisconnection()
tr.protocol = port
port.makeConnection(tr)
port.producer = port.wrappedProtocol
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.writtenThisSecond, 20)
self.assertFalse(port.wrappedProtocol.paused)
# at this point server should've written 20 bytes, 10 bytes
# above the limit so writing should be paused around 1 second
# from 'now', and resumed a second after that
tServer.clock.advance(1.05)
self.assertEqual(tServer.writtenThisSecond, 0)
self.assertTrue(port.wrappedProtocol.paused)
tServer.clock.advance(1.05)
self.assertEqual(tServer.writtenThisSecond, 0)
self.assertFalse(port.wrappedProtocol.paused)
def test_readLimit(self):
"""
Check the readLimit parameter: read data and check for the pause
status.
"""
server = Server()
tServer = TestableThrottlingFactory(task.Clock(), server, readLimit=10)
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
tr = StringTransportWithDisconnection()
tr.protocol = port
port.makeConnection(tr)
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.readThisSecond, 20)
tServer.clock.advance(1.05)
self.assertEqual(tServer.readThisSecond, 0)
self.assertEqual(tr.producerState, 'paused')
tServer.clock.advance(1.05)
self.assertEqual(tServer.readThisSecond, 0)
self.assertEqual(tr.producerState, 'producing')
tr.clear()
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.readThisSecond, 20)
tServer.clock.advance(1.05)
self.assertEqual(tServer.readThisSecond, 0)
self.assertEqual(tr.producerState, 'paused')
tServer.clock.advance(1.05)
self.assertEqual(tServer.readThisSecond, 0)
self.assertEqual(tr.producerState, 'producing')
class TimeoutTestCase(unittest.TestCase):
"""
Tests for L{policies.TimeoutFactory}.
"""
def setUp(self):
"""
Create a testable, deterministic clock, and a set of
server factory/protocol/transport.
"""
self.clock = task.Clock()
wrappedFactory = protocol.ServerFactory()
wrappedFactory.protocol = SimpleProtocol
self.factory = TestableTimeoutFactory(self.clock, wrappedFactory, 3)
self.proto = self.factory.buildProtocol(
address.IPv4Address('TCP', '127.0.0.1', 12345))
self.transport = StringTransportWithDisconnection()
self.transport.protocol = self.proto
self.proto.makeConnection(self.transport)
def test_timeout(self):
"""
Make sure that when a TimeoutFactory accepts a connection, it will
time out that connection if no data is read or written within the
timeout period.
"""
# Let almost 3 time units pass
self.clock.pump([0.0, 0.5, 1.0, 1.0, 0.4])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Now let the timer elapse
self.clock.pump([0.0, 0.2])
self.failUnless(self.proto.wrappedProtocol.disconnected)
def test_sendAvoidsTimeout(self):
"""
Make sure that writing data to a transport from a protocol
constructed by a TimeoutFactory resets the timeout countdown.
"""
# Let half the countdown period elapse
self.clock.pump([0.0, 0.5, 1.0])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Send some data (self.proto is the /real/ proto's transport, so this
# is the write that gets called)
self.proto.write(b'bytes bytes bytes')
# More time passes, putting us past the original timeout
self.clock.pump([0.0, 1.0, 1.0])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Make sure writeSequence delays timeout as well
self.proto.writeSequence([b'bytes'] * 3)
# Tick tock
self.clock.pump([0.0, 1.0, 1.0])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Don't write anything more, just let the timeout expire
self.clock.pump([0.0, 2.0])
self.failUnless(self.proto.wrappedProtocol.disconnected)
def test_receiveAvoidsTimeout(self):
"""
Make sure that receiving data also resets the timeout countdown.
"""
# Let half the countdown period elapse
self.clock.pump([0.0, 1.0, 0.5])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Some bytes arrive, they should reset the counter
self.proto.dataReceived(b'bytes bytes bytes')
# We pass the original timeout
self.clock.pump([0.0, 1.0, 1.0])
self.failIf(self.proto.wrappedProtocol.disconnected)
# Nothing more arrives though, the new timeout deadline is passed,
# the connection should be dropped.
self.clock.pump([0.0, 1.0, 1.0])
self.failUnless(self.proto.wrappedProtocol.disconnected)
class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
"""
A testable protocol with timeout facility.
@ivar timedOut: set to C{True} if a timeout has been detected.
@type timedOut: C{bool}
"""
timeOut = 3
timedOut = False
def __init__(self, clock):
"""
Initialize the protocol with a C{task.Clock} object.
"""
self.clock = clock
def connectionMade(self):
"""
Upon connection, set the timeout.
"""
self.setTimeout(self.timeOut)
def dataReceived(self, data):
"""
Reset the timeout on data.
"""
self.resetTimeout()
protocol.Protocol.dataReceived(self, data)
def connectionLost(self, reason=None):
"""
On connection lost, cancel all timeout operations.
"""
self.setTimeout(None)
def timeoutConnection(self):
"""
Flags the timedOut variable to indicate the timeout of the connection.
"""
self.timedOut = True
def callLater(self, timeout, func, *args, **kwargs):
"""
Override callLater to use the deterministic clock.
"""
return self.clock.callLater(timeout, func, *args, **kwargs)
class TestTimeout(unittest.TestCase):
"""
Tests for L{policies.TimeoutMixin}.
"""
def setUp(self):
"""
Create a testable, deterministic clock and a C{TimeoutTester} instance.
"""
self.clock = task.Clock()
self.proto = TimeoutTester(self.clock)
def test_overriddenCallLater(self):
"""
Test that the callLater of the clock is used instead of
C{reactor.callLater}.
"""
self.proto.setTimeout(10)
self.assertEqual(len(self.clock.calls), 1)
def test_timeout(self):
"""
Check that the protocol does timeout at the time specified by its
C{timeOut} attribute.
"""
self.proto.makeConnection(StringTransport())
# timeOut value is 3
self.clock.pump([0, 0.5, 1.0, 1.0])
self.failIf(self.proto.timedOut)
self.clock.pump([0, 1.0])
self.failUnless(self.proto.timedOut)
def test_noTimeout(self):
"""
Check that receiving data is delaying the timeout of the connection.
"""
self.proto.makeConnection(StringTransport())
self.clock.pump([0, 0.5, 1.0, 1.0])
self.failIf(self.proto.timedOut)
self.proto.dataReceived(b'hello there')
self.clock.pump([0, 1.0, 1.0, 0.5])
self.failIf(self.proto.timedOut)
self.clock.pump([0, 1.0])
self.failUnless(self.proto.timedOut)
def test_resetTimeout(self):
"""
Check that setting a new value for timeout cancel the previous value
and install a new timeout.
"""
self.proto.timeOut = None
self.proto.makeConnection(StringTransport())
self.proto.setTimeout(1)
self.assertEqual(self.proto.timeOut, 1)
self.clock.pump([0, 0.9])
self.failIf(self.proto.timedOut)
self.clock.pump([0, 0.2])
self.failUnless(self.proto.timedOut)
def test_cancelTimeout(self):
"""
Setting the timeout to C{None} cancel any timeout operations.
"""
self.proto.timeOut = 5
self.proto.makeConnection(StringTransport())
self.proto.setTimeout(None)
self.assertEqual(self.proto.timeOut, None)
self.clock.pump([0, 5, 5, 5])
self.failIf(self.proto.timedOut)
def test_return(self):
"""
setTimeout should return the value of the previous timeout.
"""
self.proto.timeOut = 5
self.assertEqual(self.proto.setTimeout(10), 5)
self.assertEqual(self.proto.setTimeout(None), 10)
self.assertEqual(self.proto.setTimeout(1), None)
self.assertEqual(self.proto.timeOut, 1)
# Clean up the DelayedCall
self.proto.setTimeout(None)
class LimitTotalConnectionsFactoryTestCase(unittest.TestCase):
"""Tests for policies.LimitTotalConnectionsFactory"""
def testConnectionCounting(self):
# Make a basic factory
factory = policies.LimitTotalConnectionsFactory()
factory.protocol = protocol.Protocol
# connectionCount starts at zero
self.assertEqual(0, factory.connectionCount)
# connectionCount increments as connections are made
p1 = factory.buildProtocol(None)
self.assertEqual(1, factory.connectionCount)
p2 = factory.buildProtocol(None)
self.assertEqual(2, factory.connectionCount)
# and decrements as they are lost
p1.connectionLost(None)
self.assertEqual(1, factory.connectionCount)
p2.connectionLost(None)
self.assertEqual(0, factory.connectionCount)
def testConnectionLimiting(self):
# Make a basic factory with a connection limit of 1
factory = policies.LimitTotalConnectionsFactory()
factory.protocol = protocol.Protocol
factory.connectionLimit = 1
# Make a connection
p = factory.buildProtocol(None)
self.assertNotEqual(None, p)
self.assertEqual(1, factory.connectionCount)
# Try to make a second connection, which will exceed the connection
# limit. This should return None, because overflowProtocol is None.
self.assertEqual(None, factory.buildProtocol(None))
self.assertEqual(1, factory.connectionCount)
# Define an overflow protocol
class OverflowProtocol(protocol.Protocol):
def connectionMade(self):
factory.overflowed = True
factory.overflowProtocol = OverflowProtocol
factory.overflowed = False
# Try to make a second connection again, now that we have an overflow
# protocol. Note that overflow connections count towards the connection
# count.
op = factory.buildProtocol(None)
op.makeConnection(None) # to trigger connectionMade
self.assertEqual(True, factory.overflowed)
self.assertEqual(2, factory.connectionCount)
# Close the connections.
p.connectionLost(None)
self.assertEqual(1, factory.connectionCount)
op.connectionLost(None)
self.assertEqual(0, factory.connectionCount)
class WriteSequenceEchoProtocol(EchoProtocol):
def dataReceived(self, bytes):
if bytes.find(b'vector!') != -1:
self.transport.writeSequence([bytes])
else:
EchoProtocol.dataReceived(self, bytes)
class TestLoggingFactory(policies.TrafficLoggingFactory):
openFile = None
def open(self, name):
assert self.openFile is None, "open() called too many times"
self.openFile = NativeStringIO()
return self.openFile
class LoggingFactoryTestCase(unittest.TestCase):
"""
Tests for L{policies.TrafficLoggingFactory}.
"""
def test_thingsGetLogged(self):
"""
Check the output produced by L{policies.TrafficLoggingFactory}.
"""
wrappedFactory = Server()
wrappedFactory.protocol = WriteSequenceEchoProtocol
t = StringTransportWithDisconnection()
f = TestLoggingFactory(wrappedFactory, 'test')
p = f.buildProtocol(('1.2.3.4', 5678))
t.protocol = p
p.makeConnection(t)
v = f.openFile.getvalue()
self.assertIn('*', v)
self.failIf(t.value())
p.dataReceived(b'here are some bytes')
v = f.openFile.getvalue()
self.assertIn("C 1: %r" % (b'here are some bytes',), v)
self.assertIn("S 1: %r" % (b'here are some bytes',), v)
self.assertEqual(t.value(), b'here are some bytes')
t.clear()
p.dataReceived(b'prepare for vector! to the extreme')
v = f.openFile.getvalue()
self.assertIn("SV 1: %r" % ([b'prepare for vector! to the extreme'],), v)
self.assertEqual(t.value(), b'prepare for vector! to the extreme')
p.loseConnection()
v = f.openFile.getvalue()
self.assertIn('ConnectionDone', v)
def test_counter(self):
"""
Test counter management with the resetCounter method.
"""
wrappedFactory = Server()
f = TestLoggingFactory(wrappedFactory, 'test')
self.assertEqual(f._counter, 0)
f.buildProtocol(('1.2.3.4', 5678))
self.assertEqual(f._counter, 1)
# Reset log file
f.openFile = None
f.buildProtocol(('1.2.3.4', 5679))
self.assertEqual(f._counter, 2)
f.resetCounter()
self.assertEqual(f._counter, 0)

View file

@ -0,0 +1,108 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.postfix module.
"""
from twisted.trial import unittest
from twisted.protocols import postfix
from twisted.test.proto_helpers import StringTransport
class PostfixTCPMapQuoteTestCase(unittest.TestCase):
data = [
# (raw, quoted, [aliasQuotedForms]),
('foo', 'foo'),
('foo bar', 'foo%20bar'),
('foo\tbar', 'foo%09bar'),
('foo\nbar', 'foo%0Abar', 'foo%0abar'),
('foo\r\nbar', 'foo%0D%0Abar', 'foo%0D%0abar', 'foo%0d%0Abar', 'foo%0d%0abar'),
('foo ', 'foo%20'),
(' foo', '%20foo'),
]
def testData(self):
for entry in self.data:
raw = entry[0]
quoted = entry[1:]
self.assertEqual(postfix.quote(raw), quoted[0])
for q in quoted:
self.assertEqual(postfix.unquote(q), raw)
class PostfixTCPMapServerTestCase:
data = {
# 'key': 'value',
}
chat = [
# (input, expected_output),
]
def test_chat(self):
"""
Test that I{get} and I{put} commands are responded to correctly by
L{postfix.PostfixTCPMapServer} when its factory is an instance of
L{postifx.PostfixTCPMapDictServerFactory}.
"""
factory = postfix.PostfixTCPMapDictServerFactory(self.data)
transport = StringTransport()
protocol = postfix.PostfixTCPMapServer()
protocol.service = factory
protocol.factory = factory
protocol.makeConnection(transport)
for input, expected_output in self.chat:
protocol.lineReceived(input)
self.assertEqual(
transport.value(), expected_output,
'For %r, expected %r but got %r' % (
input, expected_output, transport.value()))
transport.clear()
protocol.setTimeout(None)
def test_deferredChat(self):
"""
Test that I{get} and I{put} commands are responded to correctly by
L{postfix.PostfixTCPMapServer} when its factory is an instance of
L{postifx.PostfixTCPMapDeferringDictServerFactory}.
"""
factory = postfix.PostfixTCPMapDeferringDictServerFactory(self.data)
transport = StringTransport()
protocol = postfix.PostfixTCPMapServer()
protocol.service = factory
protocol.factory = factory
protocol.makeConnection(transport)
for input, expected_output in self.chat:
protocol.lineReceived(input)
self.assertEqual(
transport.value(), expected_output,
'For %r, expected %r but got %r' % (
input, expected_output, transport.value()))
transport.clear()
protocol.setTimeout(None)
class Valid(PostfixTCPMapServerTestCase, unittest.TestCase):
data = {
'foo': 'ThisIs Foo',
'bar': ' bar really is found\r\n',
}
chat = [
('get', "400 Command 'get' takes 1 parameters.\n"),
('get foo bar', "500 \n"),
('put', "400 Command 'put' takes 2 parameters.\n"),
('put foo', "400 Command 'put' takes 2 parameters.\n"),
('put foo bar baz', "500 put is not implemented yet.\n"),
('put foo bar', '500 put is not implemented yet.\n'),
('get foo', '200 ThisIs%20Foo\n'),
('get bar', '200 %20bar%20really%20is%20found%0D%0A\n'),
('get baz', '500 \n'),
('foo', '400 unknown command\n'),
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,236 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols package.
"""
from twisted.trial import unittest
from twisted.protocols import wire, portforward
from twisted.internet import reactor, defer, address, protocol
from twisted.test import proto_helpers
class WireTestCase(unittest.TestCase):
"""
Test wire protocols.
"""
def test_echo(self):
"""
Test wire.Echo protocol: send some data and check it send it back.
"""
t = proto_helpers.StringTransport()
a = wire.Echo()
a.makeConnection(t)
a.dataReceived("hello")
a.dataReceived("world")
a.dataReceived("how")
a.dataReceived("are")
a.dataReceived("you")
self.assertEqual(t.value(), "helloworldhowareyou")
def test_who(self):
"""
Test wire.Who protocol.
"""
t = proto_helpers.StringTransport()
a = wire.Who()
a.makeConnection(t)
self.assertEqual(t.value(), "root\r\n")
def test_QOTD(self):
"""
Test wire.QOTD protocol.
"""
t = proto_helpers.StringTransport()
a = wire.QOTD()
a.makeConnection(t)
self.assertEqual(t.value(),
"An apple a day keeps the doctor away.\r\n")
def test_discard(self):
"""
Test wire.Discard protocol.
"""
t = proto_helpers.StringTransport()
a = wire.Discard()
a.makeConnection(t)
a.dataReceived("hello")
a.dataReceived("world")
a.dataReceived("how")
a.dataReceived("are")
a.dataReceived("you")
self.assertEqual(t.value(), "")
class TestableProxyClientFactory(portforward.ProxyClientFactory):
"""
Test proxy client factory that keeps the last created protocol instance.
@ivar protoInstance: the last instance of the protocol.
@type protoInstance: L{portforward.ProxyClient}
"""
def buildProtocol(self, addr):
"""
Create the protocol instance and keeps track of it.
"""
proto = portforward.ProxyClientFactory.buildProtocol(self, addr)
self.protoInstance = proto
return proto
class TestableProxyFactory(portforward.ProxyFactory):
"""
Test proxy factory that keeps the last created protocol instance.
@ivar protoInstance: the last instance of the protocol.
@type protoInstance: L{portforward.ProxyServer}
@ivar clientFactoryInstance: client factory used by C{protoInstance} to
create forward connections.
@type clientFactoryInstance: L{TestableProxyClientFactory}
"""
def buildProtocol(self, addr):
"""
Create the protocol instance, keeps track of it, and makes it use
C{clientFactoryInstance} as client factory.
"""
proto = portforward.ProxyFactory.buildProtocol(self, addr)
self.clientFactoryInstance = TestableProxyClientFactory()
# Force the use of this specific instance
proto.clientProtocolFactory = lambda: self.clientFactoryInstance
self.protoInstance = proto
return proto
class Portforwarding(unittest.TestCase):
"""
Test port forwarding.
"""
def setUp(self):
self.serverProtocol = wire.Echo()
self.clientProtocol = protocol.Protocol()
self.openPorts = []
def tearDown(self):
try:
self.proxyServerFactory.protoInstance.transport.loseConnection()
except AttributeError:
pass
try:
pi = self.proxyServerFactory.clientFactoryInstance.protoInstance
pi.transport.loseConnection()
except AttributeError:
pass
try:
self.clientProtocol.transport.loseConnection()
except AttributeError:
pass
try:
self.serverProtocol.transport.loseConnection()
except AttributeError:
pass
return defer.gatherResults(
[defer.maybeDeferred(p.stopListening) for p in self.openPorts])
def test_portforward(self):
"""
Test port forwarding through Echo protocol.
"""
realServerFactory = protocol.ServerFactory()
realServerFactory.protocol = lambda: self.serverProtocol
realServerPort = reactor.listenTCP(0, realServerFactory,
interface='127.0.0.1')
self.openPorts.append(realServerPort)
self.proxyServerFactory = TestableProxyFactory('127.0.0.1',
realServerPort.getHost().port)
proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory,
interface='127.0.0.1')
self.openPorts.append(proxyServerPort)
nBytes = 1000
received = []
d = defer.Deferred()
def testDataReceived(data):
received.extend(data)
if len(received) >= nBytes:
self.assertEqual(''.join(received), 'x' * nBytes)
d.callback(None)
self.clientProtocol.dataReceived = testDataReceived
def testConnectionMade():
self.clientProtocol.transport.write('x' * nBytes)
self.clientProtocol.connectionMade = testConnectionMade
clientFactory = protocol.ClientFactory()
clientFactory.protocol = lambda: self.clientProtocol
reactor.connectTCP(
'127.0.0.1', proxyServerPort.getHost().port, clientFactory)
return d
def test_registerProducers(self):
"""
The proxy client registers itself as a producer of the proxy server and
vice versa.
"""
# create a ProxyServer instance
addr = address.IPv4Address('TCP', '127.0.0.1', 0)
server = portforward.ProxyFactory('127.0.0.1', 0).buildProtocol(addr)
# set the reactor for this test
reactor = proto_helpers.MemoryReactor()
server.reactor = reactor
# make the connection
serverTransport = proto_helpers.StringTransport()
server.makeConnection(serverTransport)
# check that the ProxyClientFactory is connecting to the backend
self.assertEqual(len(reactor.tcpClients), 1)
# get the factory instance and check it's the one we expect
host, port, clientFactory, timeout, _ = reactor.tcpClients[0]
self.assertIsInstance(clientFactory, portforward.ProxyClientFactory)
# Connect it
client = clientFactory.buildProtocol(addr)
clientTransport = proto_helpers.StringTransport()
client.makeConnection(clientTransport)
# check that the producers are registered
self.assertIdentical(clientTransport.producer, serverTransport)
self.assertIdentical(serverTransport.producer, clientTransport)
# check the streaming attribute in both transports
self.assertTrue(clientTransport.streaming)
self.assertTrue(serverTransport.streaming)
class StringTransportTestCase(unittest.TestCase):
"""
Test L{proto_helpers.StringTransport} helper behaviour.
"""
def test_noUnicode(self):
"""
Test that L{proto_helpers.StringTransport} doesn't accept unicode data.
"""
s = proto_helpers.StringTransport()
self.assertRaises(TypeError, s.write, u'foo')

View file

@ -0,0 +1,121 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.python.randbytes}.
"""
from __future__ import division, absolute_import
import os
from twisted.trial import unittest
from twisted.python import randbytes
class SecureRandomTestCaseBase(object):
"""
Base class for secureRandom test cases.
"""
def _check(self, source):
"""
The given random bytes source should return the number of bytes
requested each time it is called and should probably not return the
same bytes on two consecutive calls (although this is a perfectly
legitimate occurrence and rejecting it may generate a spurious failure
-- maybe we'll get lucky and the heat death with come first).
"""
for nbytes in range(17, 25):
s = source(nbytes)
self.assertEqual(len(s), nbytes)
s2 = source(nbytes)
self.assertEqual(len(s2), nbytes)
# This is crude but hey
self.assertNotEquals(s2, s)
class SecureRandomTestCase(SecureRandomTestCaseBase, unittest.TestCase):
"""
Test secureRandom under normal conditions.
"""
def test_normal(self):
"""
L{randbytes.secureRandom} should return a string of the requested
length and make some effort to make its result otherwise unpredictable.
"""
self._check(randbytes.secureRandom)
class ConditionalSecureRandomTestCase(SecureRandomTestCaseBase,
unittest.SynchronousTestCase):
"""
Test random sources one by one, then remove it to.
"""
def setUp(self):
"""
Create a L{randbytes.RandomFactory} to use in the tests.
"""
self.factory = randbytes.RandomFactory()
def errorFactory(self, nbytes):
"""
A factory raising an error when a source is not available.
"""
raise randbytes.SourceNotAvailable()
def test_osUrandom(self):
"""
L{RandomFactory._osUrandom} should work as a random source whenever
L{os.urandom} is available.
"""
self._check(self.factory._osUrandom)
def test_withoutAnything(self):
"""
Remove all secure sources and assert it raises a failure. Then try the
fallback parameter.
"""
self.factory._osUrandom = self.errorFactory
self.assertRaises(randbytes.SecureRandomNotAvailable,
self.factory.secureRandom, 18)
def wrapper():
return self.factory.secureRandom(18, fallback=True)
s = self.assertWarns(
RuntimeWarning,
"urandom unavailable - "
"proceeding with non-cryptographically secure random source",
__file__,
wrapper)
self.assertEqual(len(s), 18)
class RandomTestCaseBase(SecureRandomTestCaseBase, unittest.SynchronousTestCase):
"""
'Normal' random test cases.
"""
def test_normal(self):
"""
Test basic case.
"""
self._check(randbytes.insecureRandom)
def test_withoutGetrandbits(self):
"""
Test C{insecureRandom} without C{random.getrandbits}.
"""
factory = randbytes.RandomFactory()
factory.getrandbits = None
self._check(factory.insecureRandom)

View file

@ -0,0 +1,252 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
import types
from twisted.trial import unittest
from twisted.python import rebuild
import crash_test_dummy
f = crash_test_dummy.foo
class Foo: pass
class Bar(Foo): pass
class Baz(object): pass
class Buz(Bar, Baz): pass
class HashRaisesRuntimeError:
"""
Things that don't hash (raise an Exception) should be ignored by the
rebuilder.
@ivar hashCalled: C{bool} set to True when __hash__ is called.
"""
def __init__(self):
self.hashCalled = False
def __hash__(self):
self.hashCalled = True
raise RuntimeError('not a TypeError!')
unhashableObject = None # set in test_hashException
class RebuildTestCase(unittest.TestCase):
"""
Simple testcase for rebuilding, to at least exercise the code.
"""
def setUp(self):
self.libPath = self.mktemp()
os.mkdir(self.libPath)
self.fakelibPath = os.path.join(self.libPath, 'twisted_rebuild_fakelib')
os.mkdir(self.fakelibPath)
file(os.path.join(self.fakelibPath, '__init__.py'), 'w').close()
sys.path.insert(0, self.libPath)
def tearDown(self):
sys.path.remove(self.libPath)
def testFileRebuild(self):
from twisted.python.util import sibpath
import shutil, time
shutil.copyfile(sibpath(__file__, "myrebuilder1.py"),
os.path.join(self.fakelibPath, "myrebuilder.py"))
from twisted_rebuild_fakelib import myrebuilder
a = myrebuilder.A()
try:
object
except NameError:
pass
else:
from twisted.test import test_rebuild
b = myrebuilder.B()
class C(myrebuilder.B):
pass
test_rebuild.C = C
c = C()
i = myrebuilder.Inherit()
self.assertEqual(a.a(), 'a')
# necessary because the file has not "changed" if a second has not gone
# by in unix. This sucks, but it's not often that you'll be doing more
# than one reload per second.
time.sleep(1.1)
shutil.copyfile(sibpath(__file__, "myrebuilder2.py"),
os.path.join(self.fakelibPath, "myrebuilder.py"))
rebuild.rebuild(myrebuilder)
try:
object
except NameError:
pass
else:
b2 = myrebuilder.B()
self.assertEqual(b2.b(), 'c')
self.assertEqual(b.b(), 'c')
self.assertEqual(i.a(), 'd')
self.assertEqual(a.a(), 'b')
# more work to be done on new-style classes
# self.assertEqual(c.b(), 'c')
def testRebuild(self):
"""
Rebuilding an unchanged module.
"""
# This test would actually pass if rebuild was a no-op, but it
# ensures rebuild doesn't break stuff while being a less
# complex test than testFileRebuild.
x = crash_test_dummy.X('a')
rebuild.rebuild(crash_test_dummy, doLog=False)
# Instance rebuilding is triggered by attribute access.
x.do()
self.failUnlessIdentical(x.__class__, crash_test_dummy.X)
self.failUnlessIdentical(f, crash_test_dummy.foo)
def testComponentInteraction(self):
x = crash_test_dummy.XComponent()
x.setAdapter(crash_test_dummy.IX, crash_test_dummy.XA)
oldComponent = x.getComponent(crash_test_dummy.IX)
rebuild.rebuild(crash_test_dummy, 0)
newComponent = x.getComponent(crash_test_dummy.IX)
newComponent.method()
self.assertEqual(newComponent.__class__, crash_test_dummy.XA)
# Test that a duplicate registerAdapter is not allowed
from twisted.python import components
self.failUnlessRaises(ValueError, components.registerAdapter,
crash_test_dummy.XA, crash_test_dummy.X,
crash_test_dummy.IX)
def testUpdateInstance(self):
global Foo, Buz
b = Buz()
class Foo:
def foo(self):
pass
class Buz(Bar, Baz):
x = 10
rebuild.updateInstance(b)
assert hasattr(b, 'foo'), "Missing method on rebuilt instance"
assert hasattr(b, 'x'), "Missing class attribute on rebuilt instance"
def testBananaInteraction(self):
from twisted.python import rebuild
from twisted.spread import banana
rebuild.latestClass(banana.Banana)
def test_hashException(self):
"""
Rebuilding something that has a __hash__ that raises a non-TypeError
shouldn't cause rebuild to die.
"""
global unhashableObject
unhashableObject = HashRaisesRuntimeError()
def _cleanup():
global unhashableObject
unhashableObject = None
self.addCleanup(_cleanup)
rebuild.rebuild(rebuild)
self.assertEqual(unhashableObject.hashCalled, True)
class NewStyleTestCase(unittest.TestCase):
"""
Tests for rebuilding new-style classes of various sorts.
"""
def setUp(self):
self.m = types.ModuleType('whipping')
sys.modules['whipping'] = self.m
def tearDown(self):
del sys.modules['whipping']
del self.m
def test_slots(self):
"""
Try to rebuild a new style class with slots defined.
"""
classDefinition = (
"class SlottedClass(object):\n"
" __slots__ = ['a']\n")
exec classDefinition in self.m.__dict__
inst = self.m.SlottedClass()
inst.a = 7
exec classDefinition in self.m.__dict__
rebuild.updateInstance(inst)
self.assertEqual(inst.a, 7)
self.assertIdentical(type(inst), self.m.SlottedClass)
if sys.version_info < (2, 6):
test_slots.skip = "__class__ assignment for class with slots is only available starting Python 2.6"
def test_errorSlots(self):
"""
Try to rebuild a new style class with slots defined: this should fail.
"""
classDefinition = (
"class SlottedClass(object):\n"
" __slots__ = ['a']\n")
exec classDefinition in self.m.__dict__
inst = self.m.SlottedClass()
inst.a = 7
exec classDefinition in self.m.__dict__
self.assertRaises(rebuild.RebuildError, rebuild.updateInstance, inst)
if sys.version_info >= (2, 6):
test_errorSlots.skip = "__class__ assignment for class with slots should work starting Python 2.6"
def test_typeSubclass(self):
"""
Try to rebuild a base type subclass.
"""
classDefinition = (
"class ListSubclass(list):\n"
" pass\n")
exec classDefinition in self.m.__dict__
inst = self.m.ListSubclass()
inst.append(2)
exec classDefinition in self.m.__dict__
rebuild.updateInstance(inst)
self.assertEqual(inst[0], 2)
self.assertIdentical(type(inst), self.m.ListSubclass)
def test_instanceSlots(self):
"""
Test that when rebuilding an instance with a __slots__ attribute, it
fails accurately instead of giving a L{rebuild.RebuildError}.
"""
classDefinition = (
"class NotSlottedClass(object):\n"
" pass\n")
exec classDefinition in self.m.__dict__
inst = self.m.NotSlottedClass()
inst.__slots__ = ['a']
classDefinition = (
"class NotSlottedClass:\n"
" pass\n")
exec classDefinition in self.m.__dict__
# Moving from new-style class to old-style should fail.
self.assertRaises(TypeError, rebuild.updateInstance, inst)

View file

@ -0,0 +1,900 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for the L{twisted.python.reflect} module.
"""
from __future__ import division, absolute_import
import os
import weakref
from collections import deque
from twisted.python.compat import _PY3
from twisted.trial import unittest
from twisted.trial.unittest import SynchronousTestCase as TestCase
from twisted.python import reflect
from twisted.python.reflect import (
accumulateMethods, prefixedMethods, prefixedMethodNames,
addMethodNamesToDict, fullyQualifiedName)
from twisted.python.versions import Version
class Base(object):
"""
A no-op class which can be used to verify the behavior of
method-discovering APIs.
"""
def method(self):
"""
A no-op method which can be discovered.
"""
class Sub(Base):
"""
A subclass of a class with a method which can be discovered.
"""
class Separate(object):
"""
A no-op class with methods with differing prefixes.
"""
def good_method(self):
"""
A no-op method which a matching prefix to be discovered.
"""
def bad_method(self):
"""
A no-op method with a mismatched prefix to not be discovered.
"""
class AccumulateMethodsTests(TestCase):
"""
Tests for L{accumulateMethods} which finds methods on a class hierarchy and
adds them to a dictionary.
"""
def test_ownClass(self):
"""
If x is and instance of Base and Base defines a method named method,
L{accumulateMethods} adds an item to the given dictionary with
C{"method"} as the key and a bound method object for Base.method value.
"""
x = Base()
output = {}
accumulateMethods(x, output)
self.assertEqual({"method": x.method}, output)
def test_baseClass(self):
"""
If x is an instance of Sub and Sub is a subclass of Base and Base
defines a method named method, L{accumulateMethods} adds an item to the
given dictionary with C{"method"} as the key and a bound method object
for Base.method as the value.
"""
x = Sub()
output = {}
accumulateMethods(x, output)
self.assertEqual({"method": x.method}, output)
def test_prefix(self):
"""
If a prefix is given, L{accumulateMethods} limits its results to
methods beginning with that prefix. Keys in the resulting dictionary
also have the prefix removed from them.
"""
x = Separate()
output = {}
accumulateMethods(x, output, 'good_')
self.assertEqual({'method': x.good_method}, output)
class PrefixedMethodsTests(TestCase):
"""
Tests for L{prefixedMethods} which finds methods on a class hierarchy and
adds them to a dictionary.
"""
def test_onlyObject(self):
"""
L{prefixedMethods} returns a list of the methods discovered on an
object.
"""
x = Base()
output = prefixedMethods(x)
self.assertEqual([x.method], output)
def test_prefix(self):
"""
If a prefix is given, L{prefixedMethods} returns only methods named
with that prefix.
"""
x = Separate()
output = prefixedMethods(x, 'good_')
self.assertEqual([x.good_method], output)
class PrefixedMethodNamesTests(TestCase):
"""
Tests for L{prefixedMethodNames}.
"""
def test_method(self):
"""
L{prefixedMethodNames} returns a list including methods with the given
prefix defined on the class passed to it.
"""
self.assertEqual(["method"], prefixedMethodNames(Separate, "good_"))
def test_inheritedMethod(self):
"""
L{prefixedMethodNames} returns a list included methods with the given
prefix defined on base classes of the class passed to it.
"""
class Child(Separate):
pass
self.assertEqual(["method"], prefixedMethodNames(Child, "good_"))
class AddMethodNamesToDictTests(TestCase):
"""
Tests for L{addMethodNamesToDict}.
"""
def test_baseClass(self):
"""
If C{baseClass} is passed to L{addMethodNamesToDict}, only methods which
are a subclass of C{baseClass} are added to the result dictionary.
"""
class Alternate(object):
pass
class Child(Separate, Alternate):
def good_alternate(self):
pass
result = {}
addMethodNamesToDict(Child, result, 'good_', Alternate)
self.assertEqual({'alternate': 1}, result)
class Summer(object):
"""
A class we look up as part of the LookupsTestCase.
"""
def reallySet(self):
"""
Do something.
"""
class LookupsTestCase(TestCase):
"""
Tests for L{namedClass}, L{namedModule}, and L{namedAny}.
"""
def test_namedClassLookup(self):
"""
L{namedClass} should return the class object for the name it is passed.
"""
self.assertIdentical(
reflect.namedClass("twisted.test.test_reflect.Summer"),
Summer)
def test_namedModuleLookup(self):
"""
L{namedModule} should return the module object for the name it is
passed.
"""
from twisted.python import monkey
self.assertIdentical(
reflect.namedModule("twisted.python.monkey"), monkey)
def test_namedAnyPackageLookup(self):
"""
L{namedAny} should return the package object for the name it is passed.
"""
import twisted.python
self.assertIdentical(
reflect.namedAny("twisted.python"), twisted.python)
def test_namedAnyModuleLookup(self):
"""
L{namedAny} should return the module object for the name it is passed.
"""
from twisted.python import monkey
self.assertIdentical(
reflect.namedAny("twisted.python.monkey"), monkey)
def test_namedAnyClassLookup(self):
"""
L{namedAny} should return the class object for the name it is passed.
"""
self.assertIdentical(
reflect.namedAny("twisted.test.test_reflect.Summer"),
Summer)
def test_namedAnyAttributeLookup(self):
"""
L{namedAny} should return the object an attribute of a non-module,
non-package object is bound to for the name it is passed.
"""
# Note - not assertEqual because unbound method lookup creates a new
# object every time. This is a foolishness of Python's object
# implementation, not a bug in Twisted.
self.assertEqual(
reflect.namedAny(
"twisted.test.test_reflect.Summer.reallySet"),
Summer.reallySet)
def test_namedAnySecondAttributeLookup(self):
"""
L{namedAny} should return the object an attribute of an object which
itself was an attribute of a non-module, non-package object is bound to
for the name it is passed.
"""
self.assertIdentical(
reflect.namedAny(
"twisted.test.test_reflect."
"Summer.reallySet.__doc__"),
Summer.reallySet.__doc__)
def test_importExceptions(self):
"""
Exceptions raised by modules which L{namedAny} causes to be imported
should pass through L{namedAny} to the caller.
"""
self.assertRaises(
ZeroDivisionError,
reflect.namedAny, "twisted.test.reflect_helper_ZDE")
# Make sure that there is post-failed-import cleanup
self.assertRaises(
ZeroDivisionError,
reflect.namedAny, "twisted.test.reflect_helper_ZDE")
self.assertRaises(
ValueError,
reflect.namedAny, "twisted.test.reflect_helper_VE")
# Modules which themselves raise ImportError when imported should
# result in an ImportError
self.assertRaises(
ImportError,
reflect.namedAny, "twisted.test.reflect_helper_IE")
def test_attributeExceptions(self):
"""
If segments on the end of a fully-qualified Python name represents
attributes which aren't actually present on the object represented by
the earlier segments, L{namedAny} should raise an L{AttributeError}.
"""
self.assertRaises(
AttributeError,
reflect.namedAny, "twisted.nosuchmoduleintheworld")
# ImportError behaves somewhat differently between "import
# extant.nonextant" and "import extant.nonextant.nonextant", so test
# the latter as well.
self.assertRaises(
AttributeError,
reflect.namedAny, "twisted.nosuch.modulein.theworld")
self.assertRaises(
AttributeError,
reflect.namedAny,
"twisted.test.test_reflect.Summer.nosuchattribute")
def test_invalidNames(self):
"""
Passing a name which isn't a fully-qualified Python name to L{namedAny}
should result in one of the following exceptions:
- L{InvalidName}: the name is not a dot-separated list of Python
objects
- L{ObjectNotFound}: the object doesn't exist
- L{ModuleNotFound}: the object doesn't exist and there is only one
component in the name
"""
err = self.assertRaises(reflect.ModuleNotFound, reflect.namedAny,
'nosuchmoduleintheworld')
self.assertEqual(str(err), "No module named 'nosuchmoduleintheworld'")
# This is a dot-separated list, but it isn't valid!
err = self.assertRaises(reflect.ObjectNotFound, reflect.namedAny,
"@#$@(#.!@(#!@#")
self.assertEqual(str(err), "'@#$@(#.!@(#!@#' does not name an object")
err = self.assertRaises(reflect.ObjectNotFound, reflect.namedAny,
"tcelfer.nohtyp.detsiwt")
self.assertEqual(
str(err),
"'tcelfer.nohtyp.detsiwt' does not name an object")
err = self.assertRaises(reflect.InvalidName, reflect.namedAny, '')
self.assertEqual(str(err), 'Empty module name')
for invalidName in ['.twisted', 'twisted.', 'twisted..python']:
err = self.assertRaises(
reflect.InvalidName, reflect.namedAny, invalidName)
self.assertEqual(
str(err),
"name must be a string giving a '.'-separated list of Python "
"identifiers, not %r" % (invalidName,))
def test_requireModuleImportError(self):
"""
When module import fails with ImportError it returns the specified
default value.
"""
for name in ['nosuchmtopodule', 'no.such.module']:
default = object()
result = reflect.requireModule(name, default=default)
self.assertIs(result, default)
def test_requireModuleDefaultNone(self):
"""
When module import fails it returns C{None} by default.
"""
result = reflect.requireModule('no.such.module')
self.assertIs(None, result)
def test_requireModuleRequestedImport(self):
"""
When module import succeed it returns the module and not the default
value.
"""
from twisted.python import monkey
default = object()
self.assertIs(
reflect.requireModule('twisted.python.monkey', default=default),
monkey,
)
class Breakable(object):
breakRepr = False
breakStr = False
def __str__(self):
if self.breakStr:
raise RuntimeError("str!")
else:
return '<Breakable>'
def __repr__(self):
if self.breakRepr:
raise RuntimeError("repr!")
else:
return 'Breakable()'
class BrokenType(Breakable, type):
breakName = False
def get___name__(self):
if self.breakName:
raise RuntimeError("no name")
return 'BrokenType'
__name__ = property(get___name__)
BTBase = BrokenType('BTBase', (Breakable,),
{"breakRepr": True,
"breakStr": True})
class NoClassAttr(Breakable):
__class__ = property(lambda x: x.not_class)
class SafeRepr(TestCase):
"""
Tests for L{reflect.safe_repr} function.
"""
def test_workingRepr(self):
"""
L{reflect.safe_repr} produces the same output as C{repr} on a working
object.
"""
x = [1, 2, 3]
self.assertEqual(reflect.safe_repr(x), repr(x))
def test_brokenRepr(self):
"""
L{reflect.safe_repr} returns a string with class name, address, and
traceback when the repr call failed.
"""
b = Breakable()
b.breakRepr = True
bRepr = reflect.safe_repr(b)
self.assertIn("Breakable instance at 0x", bRepr)
# Check that the file is in the repr, but without the extension as it
# can be .py/.pyc
self.assertIn(os.path.splitext(__file__)[0], bRepr)
self.assertIn("RuntimeError: repr!", bRepr)
def test_brokenStr(self):
"""
L{reflect.safe_repr} isn't affected by a broken C{__str__} method.
"""
b = Breakable()
b.breakStr = True
self.assertEqual(reflect.safe_repr(b), repr(b))
def test_brokenClassRepr(self):
class X(BTBase):
breakRepr = True
reflect.safe_repr(X)
reflect.safe_repr(X())
def test_brokenReprIncludesID(self):
"""
C{id} is used to print the ID of the object in case of an error.
L{safe_repr} includes a traceback after a newline, so we only check
against the first line of the repr.
"""
class X(BTBase):
breakRepr = True
xRepr = reflect.safe_repr(X)
xReprExpected = ('<BrokenType instance at 0x%x with repr error:'
% (id(X),))
self.assertEqual(xReprExpected, xRepr.split('\n')[0])
def test_brokenClassStr(self):
class X(BTBase):
breakStr = True
reflect.safe_repr(X)
reflect.safe_repr(X())
def test_brokenClassAttribute(self):
"""
If an object raises an exception when accessing its C{__class__}
attribute, L{reflect.safe_repr} uses C{type} to retrieve the class
object.
"""
b = NoClassAttr()
b.breakRepr = True
bRepr = reflect.safe_repr(b)
self.assertIn("NoClassAttr instance at 0x", bRepr)
self.assertIn(os.path.splitext(__file__)[0], bRepr)
self.assertIn("RuntimeError: repr!", bRepr)
def test_brokenClassNameAttribute(self):
"""
If a class raises an exception when accessing its C{__name__} attribute
B{and} when calling its C{__str__} implementation, L{reflect.safe_repr}
returns 'BROKEN CLASS' instead of the class name.
"""
class X(BTBase):
breakName = True
xRepr = reflect.safe_repr(X())
self.assertIn("<BROKEN CLASS AT 0x", xRepr)
self.assertIn(os.path.splitext(__file__)[0], xRepr)
self.assertIn("RuntimeError: repr!", xRepr)
class SafeStr(TestCase):
"""
Tests for L{reflect.safe_str} function.
"""
def test_workingStr(self):
x = [1, 2, 3]
self.assertEqual(reflect.safe_str(x), str(x))
def test_brokenStr(self):
b = Breakable()
b.breakStr = True
reflect.safe_str(b)
def test_brokenRepr(self):
b = Breakable()
b.breakRepr = True
reflect.safe_str(b)
def test_brokenClassStr(self):
class X(BTBase):
breakStr = True
reflect.safe_str(X)
reflect.safe_str(X())
def test_brokenClassRepr(self):
class X(BTBase):
breakRepr = True
reflect.safe_str(X)
reflect.safe_str(X())
def test_brokenClassAttribute(self):
"""
If an object raises an exception when accessing its C{__class__}
attribute, L{reflect.safe_str} uses C{type} to retrieve the class
object.
"""
b = NoClassAttr()
b.breakStr = True
bStr = reflect.safe_str(b)
self.assertIn("NoClassAttr instance at 0x", bStr)
self.assertIn(os.path.splitext(__file__)[0], bStr)
self.assertIn("RuntimeError: str!", bStr)
def test_brokenClassNameAttribute(self):
"""
If a class raises an exception when accessing its C{__name__} attribute
B{and} when calling its C{__str__} implementation, L{reflect.safe_str}
returns 'BROKEN CLASS' instead of the class name.
"""
class X(BTBase):
breakName = True
xStr = reflect.safe_str(X())
self.assertIn("<BROKEN CLASS AT 0x", xStr)
self.assertIn(os.path.splitext(__file__)[0], xStr)
self.assertIn("RuntimeError: str!", xStr)
class FilenameToModule(TestCase):
"""
Test L{filenameToModuleName} detection.
"""
def setUp(self):
self.path = os.path.join(self.mktemp(), "fakepackage", "test")
os.makedirs(self.path)
with open(os.path.join(self.path, "__init__.py"), "w") as f:
f.write("")
with open(os.path.join(os.path.dirname(self.path), "__init__.py"),
"w") as f:
f.write("")
def test_directory(self):
"""
L{filenameToModuleName} returns the correct module (a package) given a
directory.
"""
module = reflect.filenameToModuleName(self.path)
self.assertEqual(module, 'fakepackage.test')
module = reflect.filenameToModuleName(self.path + os.path.sep)
self.assertEqual(module, 'fakepackage.test')
def test_file(self):
"""
L{filenameToModuleName} returns the correct module given the path to
its file.
"""
module = reflect.filenameToModuleName(
os.path.join(self.path, 'test_reflect.py'))
self.assertEqual(module, 'fakepackage.test.test_reflect')
def test_bytes(self):
"""
L{filenameToModuleName} returns the correct module given a C{bytes}
path to its file.
"""
module = reflect.filenameToModuleName(
os.path.join(self.path.encode("utf-8"), b'test_reflect.py'))
# Module names are always native string:
self.assertEqual(module, 'fakepackage.test.test_reflect')
class FullyQualifiedNameTests(TestCase):
"""
Test for L{fullyQualifiedName}.
"""
def _checkFullyQualifiedName(self, obj, expected):
"""
Helper to check that fully qualified name of C{obj} results to
C{expected}.
"""
self.assertEqual(fullyQualifiedName(obj), expected)
def test_package(self):
"""
L{fullyQualifiedName} returns the full name of a package and a
subpackage.
"""
import twisted
self._checkFullyQualifiedName(twisted, 'twisted')
import twisted.python
self._checkFullyQualifiedName(twisted.python, 'twisted.python')
def test_module(self):
"""
L{fullyQualifiedName} returns the name of a module inside a a package.
"""
import twisted.python.compat
self._checkFullyQualifiedName(
twisted.python.compat, 'twisted.python.compat')
def test_class(self):
"""
L{fullyQualifiedName} returns the name of a class and its module.
"""
self._checkFullyQualifiedName(
FullyQualifiedNameTests,
'%s.FullyQualifiedNameTests' % (__name__,))
def test_function(self):
"""
L{fullyQualifiedName} returns the name of a function inside its module.
"""
self._checkFullyQualifiedName(
fullyQualifiedName, "twisted.python.reflect.fullyQualifiedName")
def test_boundMethod(self):
"""
L{fullyQualifiedName} returns the name of a bound method inside its
class and its module.
"""
self._checkFullyQualifiedName(
self.test_boundMethod,
"%s.%s.test_boundMethod" % (__name__, self.__class__.__name__))
def test_unboundMethod(self):
"""
L{fullyQualifiedName} returns the name of an unbound method inside its
class and its module.
"""
self._checkFullyQualifiedName(
self.__class__.test_unboundMethod,
"%s.%s.test_unboundMethod" % (__name__, self.__class__.__name__))
class ObjectGrep(unittest.TestCase):
if _PY3:
# This is to be removed when fixing #6986
skip = "twisted.python.reflect.objgrep hasn't been ported to Python 3"
def test_dictionary(self):
"""
Test references search through a dictionnary, as a key or as a value.
"""
o = object()
d1 = {None: o}
d2 = {o: None}
self.assertIn("[None]", reflect.objgrep(d1, o, reflect.isSame))
self.assertIn("{None}", reflect.objgrep(d2, o, reflect.isSame))
def test_list(self):
"""
Test references search through a list.
"""
o = object()
L = [None, o]
self.assertIn("[1]", reflect.objgrep(L, o, reflect.isSame))
def test_tuple(self):
"""
Test references search through a tuple.
"""
o = object()
T = (o, None)
self.assertIn("[0]", reflect.objgrep(T, o, reflect.isSame))
def test_instance(self):
"""
Test references search through an object attribute.
"""
class Dummy:
pass
o = object()
d = Dummy()
d.o = o
self.assertIn(".o", reflect.objgrep(d, o, reflect.isSame))
def test_weakref(self):
"""
Test references search through a weakref object.
"""
class Dummy:
pass
o = Dummy()
w1 = weakref.ref(o)
self.assertIn("()", reflect.objgrep(w1, o, reflect.isSame))
def test_boundMethod(self):
"""
Test references search through method special attributes.
"""
class Dummy:
def dummy(self):
pass
o = Dummy()
m = o.dummy
self.assertIn(".__self__",
reflect.objgrep(m, m.__self__, reflect.isSame))
self.assertIn(".__self__.__class__",
reflect.objgrep(m, m.__self__.__class__, reflect.isSame))
self.assertIn(".__func__",
reflect.objgrep(m, m.__func__, reflect.isSame))
def test_everything(self):
"""
Test references search using complex set of objects.
"""
class Dummy:
def method(self):
pass
o = Dummy()
D1 = {(): "baz", None: "Quux", o: "Foosh"}
L = [None, (), D1, 3]
T = (L, {}, Dummy())
D2 = {0: "foo", 1: "bar", 2: T}
i = Dummy()
i.attr = D2
m = i.method
w = weakref.ref(m)
self.assertIn("().__self__.attr[2][0][2]{'Foosh'}",
reflect.objgrep(w, o, reflect.isSame))
def test_depthLimit(self):
"""
Test the depth of references search.
"""
a = []
b = [a]
c = [a, b]
d = [a, c]
self.assertEqual(['[0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=1))
self.assertEqual(['[0]', '[1][0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=2))
self.assertEqual(['[0]', '[1][0]', '[1][1][0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=3))
def test_deque(self):
"""
Test references search through a deque object.
"""
o = object()
D = deque()
D.append(None)
D.append(o)
self.assertIn("[1]", reflect.objgrep(D, o, reflect.isSame))
class GetClass(unittest.TestCase):
if _PY3:
oldClassNames = ['type']
else:
oldClassNames = ['class', 'classobj']
def testOld(self):
class OldClass:
pass
old = OldClass()
self.assertIn(reflect.getClass(OldClass).__name__, self.oldClassNames)
self.assertEqual(reflect.getClass(old).__name__, 'OldClass')
def testNew(self):
class NewClass(object):
pass
new = NewClass()
self.assertEqual(reflect.getClass(NewClass).__name__, 'type')
self.assertEqual(reflect.getClass(new).__name__, 'NewClass')
if not _PY3:
# The functions tested below are deprecated but still used by external
# projects like Nevow 0.10. They are not going to be ported to Python 3
# (hence the condition above) and will be removed as soon as no project used
# by Twisted will depend on these functions. Also, have a look at the
# comments related to those functions in twisted.python.reflect.
class DeprecationTestCase(unittest.TestCase):
"""
Test deprecations in twisted.python.reflect
"""
def test_allYourBase(self):
"""
Test deprecation of L{reflect.allYourBase}. See #5481 for removal.
"""
self.callDeprecated(
(Version("Twisted", 11, 0, 0), "inspect.getmro"),
reflect.allYourBase, DeprecationTestCase)
def test_accumulateBases(self):
"""
Test deprecation of L{reflect.accumulateBases}. See #5481 for removal.
"""
l = []
self.callDeprecated(
(Version("Twisted", 11, 0, 0), "inspect.getmro"),
reflect.accumulateBases, DeprecationTestCase, l, None)
def test_getcurrent(self):
"""
Test deprecation of L{reflect.getcurrent}.
"""
class C:
pass
self.callDeprecated(
Version("Twisted", 14, 0, 0),
reflect.getcurrent, C)
def test_isinst(self):
"""
Test deprecation of L{reflect.isinst}.
"""
self.callDeprecated(
(Version("Twisted", 14, 0, 0), "isinstance"),
reflect.isinst, object(), object)

View file

@ -0,0 +1,63 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.python import roots
import types
class RootsTest(unittest.TestCase):
def testExceptions(self):
request = roots.Request()
try:
request.write("blah")
except NotImplementedError:
pass
else:
self.fail()
try:
request.finish()
except NotImplementedError:
pass
else:
self.fail()
def testCollection(self):
collection = roots.Collection()
collection.putEntity("x", 'test')
self.assertEqual(collection.getStaticEntity("x"),
'test')
collection.delEntity("x")
self.assertEqual(collection.getStaticEntity('x'),
None)
try:
collection.storeEntity("x", None)
except NotImplementedError:
pass
else:
self.fail()
try:
collection.removeEntity("x", None)
except NotImplementedError:
pass
else:
self.fail()
def testConstrained(self):
class const(roots.Constrained):
def nameConstraint(self, name):
return (name == 'x')
c = const()
self.assertEqual(c.putEntity('x', 'test'), None)
self.failUnlessRaises(roots.ConstraintViolation,
c.putEntity, 'y', 'test')
def testHomogenous(self):
h = roots.Homogenous()
h.entityType = types.IntType
h.putEntity('a', 1)
self.assertEqual(h.getStaticEntity('a'),1 )
self.failUnlessRaises(roots.ConstraintViolation,
h.putEntity, 'x', 'y')

View file

@ -0,0 +1,61 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for C{setup.py}, Twisted's distutils integration file.
"""
from __future__ import division, absolute_import
import os, sys
import twisted
from twisted.trial.unittest import SynchronousTestCase
from twisted.python.filepath import FilePath
from twisted.python.dist import getExtensions
# Get rid of the UTF-8 encoding and bytes topfiles segment when FilePath
# supports unicode. #2366, #4736, #5203. Also #4743, which requires checking
# setup.py, not just the topfiles directory.
if not FilePath(twisted.__file__.encode('utf-8')).sibling(b'topfiles').child(b'setup.py').exists():
sourceSkip = "Only applies to source checkout of Twisted"
else:
sourceSkip = None
class TwistedExtensionsTests(SynchronousTestCase):
if sourceSkip is not None:
skip = sourceSkip
def setUp(self):
"""
Change the working directory to the parent of the C{twisted} package so
that L{twisted.python.dist.getExtensions} finds Twisted's own extension
definitions.
"""
self.addCleanup(os.chdir, os.getcwd())
os.chdir(FilePath(twisted.__file__).parent().parent().path)
def test_initgroups(self):
"""
If C{os.initgroups} is present (Python 2.7 and Python 3.3 and newer),
L{twisted.python._initgroups} is not returned as an extension to build
from L{getExtensions}.
"""
extensions = getExtensions()
found = None
for extension in extensions:
if extension.name == "twisted.python._initgroups":
found = extension
if sys.version_info[:2] >= (2, 7):
self.assertIdentical(
None, found,
"Should not have found twisted.python._initgroups extension "
"definition.")
else:
self.assertNotIdentical(
None, found,
"Should have found twisted.python._initgroups extension "
"definition.")

View file

@ -0,0 +1,26 @@
"""Test win32 shortcut script
"""
from twisted.trial import unittest
import os
if os.name == 'nt':
skipWindowsNopywin32 = None
try:
from twisted.python import shortcut
except ImportError:
skipWindowsNopywin32 = ("On windows, twisted.python.shortcut is not "
"available in the absence of win32com.")
import os.path
import sys
class ShortcutTest(unittest.TestCase):
def testCreate(self):
s1=shortcut.Shortcut("test_shortcut.py")
tempname=self.mktemp() + '.lnk'
s1.save(tempname)
self.assert_(os.path.exists(tempname))
sc=shortcut.open(tempname)
self.assert_(sc.GetPath(0)[0].endswith('test_shortcut.py'))
ShortcutTest.skip = skipWindowsNopywin32

View file

@ -0,0 +1,984 @@
# -*- test-case-name: twisted.test.test_sip -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""Session Initialization Protocol tests."""
from twisted.trial import unittest, util
from twisted.protocols import sip
from twisted.internet import defer, reactor, utils
from twisted.python.versions import Version
from twisted.test import proto_helpers
from twisted import cred
import twisted.cred.portal
import twisted.cred.checkers
from zope.interface import implements
# request, prefixed by random CRLFs
request1 = "\n\r\n\n\r" + """\
INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 4
abcd""".replace("\n", "\r\n")
# request, no content-length
request2 = """INVITE sip:foo SIP/2.0
From: mo
To: joe
1234""".replace("\n", "\r\n")
# request, with garbage after
request3 = """INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 4
1234
lalalal""".replace("\n", "\r\n")
# three requests
request4 = """INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 0
INVITE sip:loop SIP/2.0
From: foo
To: bar
Content-Length: 4
abcdINVITE sip:loop SIP/2.0
From: foo
To: bar
Content-Length: 4
1234""".replace("\n", "\r\n")
# response, no content
response1 = """SIP/2.0 200 OK
From: foo
To:bar
Content-Length: 0
""".replace("\n", "\r\n")
# short header version
request_short = """\
INVITE sip:foo SIP/2.0
f: mo
t: joe
l: 4
abcd""".replace("\n", "\r\n")
request_natted = """\
INVITE sip:foo SIP/2.0
Via: SIP/2.0/UDP 10.0.0.1:5060;rport
""".replace("\n", "\r\n")
# multiline headers (example from RFC 3621).
response_multiline = """\
SIP/2.0 200 OK
Via: SIP/2.0/UDP server10.biloxi.com
;branch=z9hG4bKnashds8;received=192.0.2.3
Via: SIP/2.0/UDP bigbox3.site3.atlanta.com
;branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2
Via: SIP/2.0/UDP pc33.atlanta.com
;branch=z9hG4bK776asdhds ;received=192.0.2.1
To: Bob <sip:bob@biloxi.com>;tag=a6c85cf
From: Alice <sip:alice@atlanta.com>;tag=1928301774
Call-ID: a84b4c76e66710@pc33.atlanta.com
CSeq: 314159 INVITE
Contact: <sip:bob@192.0.2.4>
Content-Type: application/sdp
Content-Length: 0
\n""".replace("\n", "\r\n")
class TestRealm:
def requestAvatar(self, avatarId, mind, *interfaces):
return sip.IContact, None, lambda: None
class MessageParsingTestCase(unittest.TestCase):
def setUp(self):
self.l = []
self.parser = sip.MessagesParser(self.l.append)
def feedMessage(self, message):
self.parser.dataReceived(message)
self.parser.dataDone()
def validateMessage(self, m, method, uri, headers, body):
"""Validate Requests."""
self.assertEqual(m.method, method)
self.assertEqual(m.uri.toString(), uri)
self.assertEqual(m.headers, headers)
self.assertEqual(m.body, body)
self.assertEqual(m.finished, 1)
def testSimple(self):
l = self.l
self.feedMessage(request1)
self.assertEqual(len(l), 1)
self.validateMessage(
l[0], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
"abcd")
def testTwoMessages(self):
l = self.l
self.feedMessage(request1)
self.feedMessage(request2)
self.assertEqual(len(l), 2)
self.validateMessage(
l[0], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
"abcd")
self.validateMessage(l[1], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"]},
"1234")
def testGarbage(self):
l = self.l
self.feedMessage(request3)
self.assertEqual(len(l), 1)
self.validateMessage(
l[0], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
"1234")
def testThreeInOne(self):
l = self.l
self.feedMessage(request4)
self.assertEqual(len(l), 3)
self.validateMessage(
l[0], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"], "content-length": ["0"]},
"")
self.validateMessage(
l[1], "INVITE", "sip:loop",
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
"abcd")
self.validateMessage(
l[2], "INVITE", "sip:loop",
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
"1234")
def testShort(self):
l = self.l
self.feedMessage(request_short)
self.assertEqual(len(l), 1)
self.validateMessage(
l[0], "INVITE", "sip:foo",
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
"abcd")
def testSimpleResponse(self):
l = self.l
self.feedMessage(response1)
self.assertEqual(len(l), 1)
m = l[0]
self.assertEqual(m.code, 200)
self.assertEqual(m.phrase, "OK")
self.assertEqual(
m.headers,
{"from": ["foo"], "to": ["bar"], "content-length": ["0"]})
self.assertEqual(m.body, "")
self.assertEqual(m.finished, 1)
def test_multiLine(self):
"""
A header may be split across multiple lines. Subsequent lines begin
with C{" "} or C{"\\t"}.
"""
l = self.l
self.feedMessage(response_multiline)
self.assertEquals(len(l), 1)
m = l[0]
self.assertEquals(
m.headers['via'][0],
"SIP/2.0/UDP server10.biloxi.com;"
"branch=z9hG4bKnashds8;received=192.0.2.3")
self.assertEquals(
m.headers['via'][1],
"SIP/2.0/UDP bigbox3.site3.atlanta.com;"
"branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2")
self.assertEquals(
m.headers['via'][2],
"SIP/2.0/UDP pc33.atlanta.com;"
"branch=z9hG4bK776asdhds ;received=192.0.2.1")
class MessageParsingTestCase2(MessageParsingTestCase):
"""Same as base class, but feed data char by char."""
def feedMessage(self, message):
for c in message:
self.parser.dataReceived(c)
self.parser.dataDone()
class MakeMessageTestCase(unittest.TestCase):
def testRequest(self):
r = sip.Request("INVITE", "sip:foo")
r.addHeader("foo", "bar")
self.assertEqual(
r.toString(),
"INVITE sip:foo SIP/2.0\r\nFoo: bar\r\n\r\n")
def testResponse(self):
r = sip.Response(200, "OK")
r.addHeader("foo", "bar")
r.addHeader("Content-Length", "4")
r.bodyDataReceived("1234")
self.assertEqual(
r.toString(),
"SIP/2.0 200 OK\r\nFoo: bar\r\nContent-Length: 4\r\n\r\n1234")
def testStatusCode(self):
r = sip.Response(200)
self.assertEqual(r.toString(), "SIP/2.0 200 OK\r\n\r\n")
class ViaTestCase(unittest.TestCase):
def checkRoundtrip(self, v):
s = v.toString()
self.assertEqual(s, sip.parseViaHeader(s).toString())
def testExtraWhitespace(self):
v1 = sip.parseViaHeader('SIP/2.0/UDP 192.168.1.1:5060')
v2 = sip.parseViaHeader('SIP/2.0/UDP 192.168.1.1:5060')
self.assertEqual(v1.transport, v2.transport)
self.assertEqual(v1.host, v2.host)
self.assertEqual(v1.port, v2.port)
def test_complex(self):
"""
Test parsing a Via header with one of everything.
"""
s = ("SIP/2.0/UDP first.example.com:4000;ttl=16;maddr=224.2.0.1"
" ;branch=a7c6a8dlze (Example)")
v = sip.parseViaHeader(s)
self.assertEqual(v.transport, "UDP")
self.assertEqual(v.host, "first.example.com")
self.assertEqual(v.port, 4000)
self.assertEqual(v.rport, None)
self.assertEqual(v.rportValue, None)
self.assertEqual(v.rportRequested, False)
self.assertEqual(v.ttl, 16)
self.assertEqual(v.maddr, "224.2.0.1")
self.assertEqual(v.branch, "a7c6a8dlze")
self.assertEqual(v.hidden, 0)
self.assertEqual(v.toString(),
"SIP/2.0/UDP first.example.com:4000"
";ttl=16;branch=a7c6a8dlze;maddr=224.2.0.1")
self.checkRoundtrip(v)
def test_simple(self):
"""
Test parsing a simple Via header.
"""
s = "SIP/2.0/UDP example.com;hidden"
v = sip.parseViaHeader(s)
self.assertEqual(v.transport, "UDP")
self.assertEqual(v.host, "example.com")
self.assertEqual(v.port, 5060)
self.assertEqual(v.rport, None)
self.assertEqual(v.rportValue, None)
self.assertEqual(v.rportRequested, False)
self.assertEqual(v.ttl, None)
self.assertEqual(v.maddr, None)
self.assertEqual(v.branch, None)
self.assertEqual(v.hidden, True)
self.assertEqual(v.toString(),
"SIP/2.0/UDP example.com:5060;hidden")
self.checkRoundtrip(v)
def testSimpler(self):
v = sip.Via("example.com")
self.checkRoundtrip(v)
def test_deprecatedRPort(self):
"""
Setting rport to True is deprecated, but still produces a Via header
with the expected properties.
"""
v = sip.Via("foo.bar", rport=True)
warnings = self.flushWarnings(
offendingFunctions=[self.test_deprecatedRPort])
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]['message'],
'rport=True is deprecated since Twisted 9.0.')
self.assertEqual(
warnings[0]['category'],
DeprecationWarning)
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
self.assertEqual(v.rport, True)
self.assertEqual(v.rportRequested, True)
self.assertEqual(v.rportValue, None)
def test_rport(self):
"""
An rport setting of None should insert the parameter with no value.
"""
v = sip.Via("foo.bar", rport=None)
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
self.assertEqual(v.rportRequested, True)
self.assertEqual(v.rportValue, None)
def test_rportValue(self):
"""
An rport numeric setting should insert the parameter with the number
value given.
"""
v = sip.Via("foo.bar", rport=1)
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport=1")
self.assertEqual(v.rportRequested, False)
self.assertEqual(v.rportValue, 1)
self.assertEqual(v.rport, 1)
def testNAT(self):
s = "SIP/2.0/UDP 10.0.0.1:5060;received=22.13.1.5;rport=12345"
v = sip.parseViaHeader(s)
self.assertEqual(v.transport, "UDP")
self.assertEqual(v.host, "10.0.0.1")
self.assertEqual(v.port, 5060)
self.assertEqual(v.received, "22.13.1.5")
self.assertEqual(v.rport, 12345)
self.assertNotEquals(v.toString().find("rport=12345"), -1)
def test_unknownParams(self):
"""
Parsing and serializing Via headers with unknown parameters should work.
"""
s = "SIP/2.0/UDP example.com:5060;branch=a12345b;bogus;pie=delicious"
v = sip.parseViaHeader(s)
self.assertEqual(v.toString(), s)
class URLTestCase(unittest.TestCase):
def testRoundtrip(self):
for url in [
"sip:j.doe@big.com",
"sip:j.doe:secret@big.com;transport=tcp",
"sip:j.doe@big.com?subject=project",
"sip:example.com",
]:
self.assertEqual(sip.parseURL(url).toString(), url)
def testComplex(self):
s = ("sip:user:pass@hosta:123;transport=udp;user=phone;method=foo;"
"ttl=12;maddr=1.2.3.4;blah;goo=bar?a=b&c=d")
url = sip.parseURL(s)
for k, v in [("username", "user"), ("password", "pass"),
("host", "hosta"), ("port", 123),
("transport", "udp"), ("usertype", "phone"),
("method", "foo"), ("ttl", 12),
("maddr", "1.2.3.4"), ("other", ["blah", "goo=bar"]),
("headers", {"a": "b", "c": "d"})]:
self.assertEqual(getattr(url, k), v)
class ParseTestCase(unittest.TestCase):
def testParseAddress(self):
for address, name, urls, params in [
('"A. G. Bell" <sip:foo@example.com>',
"A. G. Bell", "sip:foo@example.com", {}),
("Anon <sip:foo@example.com>", "Anon", "sip:foo@example.com", {}),
("sip:foo@example.com", "", "sip:foo@example.com", {}),
("<sip:foo@example.com>", "", "sip:foo@example.com", {}),
("foo <sip:foo@example.com>;tag=bar;foo=baz", "foo",
"sip:foo@example.com", {"tag": "bar", "foo": "baz"}),
]:
gname, gurl, gparams = sip.parseAddress(address)
self.assertEqual(name, gname)
self.assertEqual(gurl.toString(), urls)
self.assertEqual(gparams, params)
class DummyLocator:
implements(sip.ILocator)
def getAddress(self, logicalURL):
return defer.succeed(sip.URL("server.com", port=5060))
class FailingLocator:
implements(sip.ILocator)
def getAddress(self, logicalURL):
return defer.fail(LookupError())
class ProxyTestCase(unittest.TestCase):
def setUp(self):
self.proxy = sip.Proxy("127.0.0.1")
self.proxy.locator = DummyLocator()
self.sent = []
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
def testRequestForward(self):
r = sip.Request("INVITE", "sip:foo")
r.addHeader("via", sip.Via("1.2.3.4").toString())
r.addHeader("via", sip.Via("1.2.3.5").toString())
r.addHeader("foo", "bar")
r.addHeader("to", "<sip:joe@server.com>")
r.addHeader("contact", "<sip:joe@1.2.3.5>")
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual(dest.port, 5060)
self.assertEqual(dest.host, "server.com")
self.assertEqual(m.uri.toString(), "sip:foo")
self.assertEqual(m.method, "INVITE")
self.assertEqual(m.headers["via"],
["SIP/2.0/UDP 127.0.0.1:5060",
"SIP/2.0/UDP 1.2.3.4:5060",
"SIP/2.0/UDP 1.2.3.5:5060"])
def testReceivedRequestForward(self):
r = sip.Request("INVITE", "sip:foo")
r.addHeader("via", sip.Via("1.2.3.4").toString())
r.addHeader("foo", "bar")
r.addHeader("to", "<sip:joe@server.com>")
r.addHeader("contact", "<sip:joe@1.2.3.4>")
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
dest, m = self.sent[0]
self.assertEqual(m.headers["via"],
["SIP/2.0/UDP 127.0.0.1:5060",
"SIP/2.0/UDP 1.2.3.4:5060;received=1.1.1.1"])
def testResponseWrongVia(self):
# first via must match proxy's address
r = sip.Response(200)
r.addHeader("via", sip.Via("foo.com").toString())
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
self.assertEqual(len(self.sent), 0)
def testResponseForward(self):
r = sip.Response(200)
r.addHeader("via", sip.Via("127.0.0.1").toString())
r.addHeader("via", sip.Via("client.com", port=1234).toString())
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual((dest.host, dest.port), ("client.com", 1234))
self.assertEqual(m.code, 200)
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:1234"])
def testReceivedResponseForward(self):
r = sip.Response(200)
r.addHeader("via", sip.Via("127.0.0.1").toString())
r.addHeader(
"via",
sip.Via("10.0.0.1", received="client.com").toString())
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
def testResponseToUs(self):
r = sip.Response(200)
r.addHeader("via", sip.Via("127.0.0.1").toString())
l = []
self.proxy.gotResponse = lambda *a: l.append(a)
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
self.assertEqual(len(l), 1)
m, addr = l[0]
self.assertEqual(len(m.headers.get("via", [])), 0)
self.assertEqual(m.code, 200)
def testLoop(self):
r = sip.Request("INVITE", "sip:foo")
r.addHeader("via", sip.Via("1.2.3.4").toString())
r.addHeader("via", sip.Via("127.0.0.1").toString())
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
self.assertEqual(self.sent, [])
def testCantForwardRequest(self):
r = sip.Request("INVITE", "sip:foo")
r.addHeader("via", sip.Via("1.2.3.4").toString())
r.addHeader("to", "<sip:joe@server.com>")
self.proxy.locator = FailingLocator()
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual((dest.host, dest.port), ("1.2.3.4", 5060))
self.assertEqual(m.code, 404)
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP 1.2.3.4:5060"])
def testCantForwardResponse(self):
pass
#testCantForwardResponse.skip = "not implemented yet"
class RegistrationTestCase(unittest.TestCase):
def setUp(self):
self.proxy = sip.RegisterProxy(host="127.0.0.1")
self.registry = sip.InMemoryRegistry("bell.example.com")
self.proxy.registry = self.proxy.locator = self.registry
self.sent = []
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
setUp = utils.suppressWarnings(setUp,
util.suppress(category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
def tearDown(self):
for d, uri in self.registry.users.values():
d.cancel()
del self.proxy
def register(self):
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@client.com:1234")
r.addHeader("via", sip.Via("client.com").toString())
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
def unregister(self):
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "*")
r.addHeader("via", sip.Via("client.com").toString())
r.addHeader("expires", "0")
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
def testRegister(self):
self.register()
dest, m = self.sent[0]
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
self.assertEqual(m.code, 200)
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
self.failUnless(
int(m.headers["expires"][0]) in (3600, 3601, 3599, 3598))
self.assertEqual(len(self.registry.users), 1)
dc, uri = self.registry.users["joe"]
self.assertEqual(uri.toString(), "sip:joe@client.com:5060")
d = self.proxy.locator.getAddress(sip.URL(username="joe",
host="bell.example.com"))
d.addCallback(lambda desturl : (desturl.host, desturl.port))
d.addCallback(self.assertEqual, ('client.com', 5060))
return d
def testUnregister(self):
self.register()
self.unregister()
dest, m = self.sent[1]
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
self.assertEqual(m.code, 200)
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
self.assertEqual(m.headers["expires"], ["0"])
self.assertEqual(self.registry.users, {})
def addPortal(self):
r = TestRealm()
p = cred.portal.Portal(r)
c = cred.checkers.InMemoryUsernamePasswordDatabaseDontUse()
c.addUser('userXname@127.0.0.1', 'passXword')
p.registerChecker(c)
self.proxy.portal = p
def testFailedAuthentication(self):
self.addPortal()
self.register()
self.assertEqual(len(self.registry.users), 0)
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual(m.code, 401)
def test_basicAuthentication(self):
"""
Test that registration with basic authentication suceeds.
"""
self.addPortal()
self.proxy.authorizers = self.proxy.authorizers.copy()
self.proxy.authorizers['basic'] = sip.BasicAuthorizer()
warnings = self.flushWarnings(
offendingFunctions=[self.test_basicAuthentication])
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]['message'],
"twisted.protocols.sip.BasicAuthorizer was deprecated in "
"Twisted 9.0.0")
self.assertEqual(
warnings[0]['category'],
DeprecationWarning)
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@client.com:1234")
r.addHeader("via", sip.Via("client.com").toString())
r.addHeader("authorization",
"Basic " + "userXname:passXword".encode('base64'))
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
self.assertEqual(len(self.registry.users), 1)
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual(m.code, 200)
def test_failedBasicAuthentication(self):
"""
Failed registration with basic authentication results in an
unauthorized error response.
"""
self.addPortal()
self.proxy.authorizers = self.proxy.authorizers.copy()
self.proxy.authorizers['basic'] = sip.BasicAuthorizer()
warnings = self.flushWarnings(
offendingFunctions=[self.test_failedBasicAuthentication])
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]['message'],
"twisted.protocols.sip.BasicAuthorizer was deprecated in "
"Twisted 9.0.0")
self.assertEqual(
warnings[0]['category'],
DeprecationWarning)
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@client.com:1234")
r.addHeader("via", sip.Via("client.com").toString())
r.addHeader(
"authorization", "Basic " + "userXname:password".encode('base64'))
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
self.assertEqual(len(self.registry.users), 0)
self.assertEqual(len(self.sent), 1)
dest, m = self.sent[0]
self.assertEqual(m.code, 401)
def testWrongDomainRegister(self):
r = sip.Request("REGISTER", "sip:wrong.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@client.com:1234")
r.addHeader("via", sip.Via("client.com").toString())
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
self.assertEqual(len(self.sent), 0)
def testWrongToDomainRegister(self):
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@foo.com")
r.addHeader("contact", "sip:joe@client.com:1234")
r.addHeader("via", sip.Via("client.com").toString())
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
self.assertEqual(len(self.sent), 0)
def testWrongDomainLookup(self):
self.register()
url = sip.URL(username="joe", host="foo.com")
d = self.proxy.locator.getAddress(url)
self.assertFailure(d, LookupError)
return d
def testNoContactLookup(self):
self.register()
url = sip.URL(username="jane", host="bell.example.com")
d = self.proxy.locator.getAddress(url)
self.assertFailure(d, LookupError)
return d
class Client(sip.Base):
def __init__(self):
sip.Base.__init__(self)
self.received = []
self.deferred = defer.Deferred()
def handle_response(self, response, addr):
self.received.append(response)
self.deferred.callback(self.received)
class LiveTest(unittest.TestCase):
def setUp(self):
self.proxy = sip.RegisterProxy(host="127.0.0.1")
self.registry = sip.InMemoryRegistry("bell.example.com")
self.proxy.registry = self.proxy.locator = self.registry
self.serverPort = reactor.listenUDP(
0, self.proxy, interface="127.0.0.1")
self.client = Client()
self.clientPort = reactor.listenUDP(
0, self.client, interface="127.0.0.1")
self.serverAddress = (self.serverPort.getHost().host,
self.serverPort.getHost().port)
setUp = utils.suppressWarnings(setUp,
util.suppress(category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
def tearDown(self):
for d, uri in self.registry.users.values():
d.cancel()
d1 = defer.maybeDeferred(self.clientPort.stopListening)
d2 = defer.maybeDeferred(self.serverPort.stopListening)
return defer.gatherResults([d1, d2])
def testRegister(self):
p = self.clientPort.getHost().port
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
r.addHeader("via", sip.Via("127.0.0.1", port=p).toString())
self.client.sendMessage(
sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r)
d = self.client.deferred
def check(received):
self.assertEqual(len(received), 1)
r = received[0]
self.assertEqual(r.code, 200)
d.addCallback(check)
return d
def test_amoralRPort(self):
"""
rport is allowed without a value, apparently because server
implementors might be too stupid to check the received port
against 5060 and see if they're equal, and because client
implementors might be too stupid to bind to port 5060, or set a
value on the rport parameter they send if they bind to another
port.
"""
p = self.clientPort.getHost().port
r = sip.Request("REGISTER", "sip:bell.example.com")
r.addHeader("to", "sip:joe@bell.example.com")
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
r.addHeader("via", sip.Via("127.0.0.1", port=p, rport=True).toString())
warnings = self.flushWarnings(
offendingFunctions=[self.test_amoralRPort])
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]['message'],
'rport=True is deprecated since Twisted 9.0.')
self.assertEqual(
warnings[0]['category'],
DeprecationWarning)
self.client.sendMessage(sip.URL(host="127.0.0.1",
port=self.serverAddress[1]),
r)
d = self.client.deferred
def check(received):
self.assertEqual(len(received), 1)
r = received[0]
self.assertEqual(r.code, 200)
d.addCallback(check)
return d
registerRequest = """
REGISTER sip:intarweb.us SIP/2.0\r
Via: SIP/2.0/UDP 192.168.1.100:50609\r
From: <sip:exarkun@intarweb.us:50609>\r
To: <sip:exarkun@intarweb.us:50609>\r
Contact: "exarkun" <sip:exarkun@192.168.1.100:50609>\r
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
CSeq: 9898 REGISTER\r
Expires: 500\r
User-Agent: X-Lite build 1061\r
Content-Length: 0\r
\r
"""
challengeResponse = """\
SIP/2.0 401 Unauthorized\r
Via: SIP/2.0/UDP 192.168.1.100:50609;received=127.0.0.1;rport=5632\r
To: <sip:exarkun@intarweb.us:50609>\r
From: <sip:exarkun@intarweb.us:50609>\r
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
CSeq: 9898 REGISTER\r
WWW-Authenticate: Digest nonce="92956076410767313901322208775",opaque="1674186428",qop-options="auth",algorithm="MD5",realm="intarweb.us"\r
\r
"""
authRequest = """\
REGISTER sip:intarweb.us SIP/2.0\r
Via: SIP/2.0/UDP 192.168.1.100:50609\r
From: <sip:exarkun@intarweb.us:50609>\r
To: <sip:exarkun@intarweb.us:50609>\r
Contact: "exarkun" <sip:exarkun@192.168.1.100:50609>\r
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
CSeq: 9899 REGISTER\r
Expires: 500\r
Authorization: Digest username="exarkun",realm="intarweb.us",nonce="92956076410767313901322208775",response="4a47980eea31694f997369214292374b",uri="sip:intarweb.us",algorithm=MD5,opaque="1674186428"\r
User-Agent: X-Lite build 1061\r
Content-Length: 0\r
\r
"""
okResponse = """\
SIP/2.0 200 OK\r
Via: SIP/2.0/UDP 192.168.1.100:50609;received=127.0.0.1;rport=5632\r
To: <sip:exarkun@intarweb.us:50609>\r
From: <sip:exarkun@intarweb.us:50609>\r
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
CSeq: 9899 REGISTER\r
Contact: sip:exarkun@127.0.0.1:5632\r
Expires: 3600\r
Content-Length: 0\r
\r
"""
class FakeDigestAuthorizer(sip.DigestAuthorizer):
def generateNonce(self):
return '92956076410767313901322208775'
def generateOpaque(self):
return '1674186428'
class FakeRegistry(sip.InMemoryRegistry):
"""Make sure expiration is always seen to be 3600.
Otherwise slow reactors fail tests incorrectly.
"""
def _cbReg(self, reg):
if 3600 < reg.secondsToExpiry or reg.secondsToExpiry < 3598:
raise RuntimeError(
"bad seconds to expire: %s" % reg.secondsToExpiry)
reg.secondsToExpiry = 3600
return reg
def getRegistrationInfo(self, uri):
d = sip.InMemoryRegistry.getRegistrationInfo(self, uri)
return d.addCallback(self._cbReg)
def registerAddress(self, domainURL, logicalURL, physicalURL):
d = sip.InMemoryRegistry.registerAddress(
self, domainURL, logicalURL, physicalURL)
return d.addCallback(self._cbReg)
class AuthorizationTestCase(unittest.TestCase):
def setUp(self):
self.proxy = sip.RegisterProxy(host="intarweb.us")
self.proxy.authorizers = self.proxy.authorizers.copy()
self.proxy.authorizers['digest'] = FakeDigestAuthorizer()
self.registry = FakeRegistry("intarweb.us")
self.proxy.registry = self.proxy.locator = self.registry
self.transport = proto_helpers.FakeDatagramTransport()
self.proxy.transport = self.transport
r = TestRealm()
p = cred.portal.Portal(r)
c = cred.checkers.InMemoryUsernamePasswordDatabaseDontUse()
c.addUser('exarkun@intarweb.us', 'password')
p.registerChecker(c)
self.proxy.portal = p
setUp = utils.suppressWarnings(setUp,
util.suppress(category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
def tearDown(self):
for d, uri in self.registry.users.values():
d.cancel()
del self.proxy
def testChallenge(self):
self.proxy.datagramReceived(registerRequest, ("127.0.0.1", 5632))
self.assertEqual(
self.transport.written[-1],
((challengeResponse, ("127.0.0.1", 5632)))
)
self.transport.written = []
self.proxy.datagramReceived(authRequest, ("127.0.0.1", 5632))
self.assertEqual(
self.transport.written[-1],
((okResponse, ("127.0.0.1", 5632)))
)
testChallenge.suppress = [
util.suppress(
category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'),
util.suppress(
category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestedCredentials was deprecated'),
util.suppress(
category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestCalcHA1 was deprecated'),
util.suppress(
category=DeprecationWarning,
message=r'twisted.protocols.sip.DigestCalcResponse was deprecated')]
class DeprecationTests(unittest.TestCase):
"""
Tests for deprecation of obsolete components of L{twisted.protocols.sip}.
"""
def test_deprecatedDigestCalcHA1(self):
"""
L{sip.DigestCalcHA1} is deprecated.
"""
self.callDeprecated(Version("Twisted", 9, 0, 0),
sip.DigestCalcHA1, '', '', '', '', '', '')
def test_deprecatedDigestCalcResponse(self):
"""
L{sip.DigestCalcResponse} is deprecated.
"""
self.callDeprecated(Version("Twisted", 9, 0, 0),
sip.DigestCalcResponse, '', '', '', '', '', '', '',
'')
def test_deprecatedBasicAuthorizer(self):
"""
L{sip.BasicAuthorizer} is deprecated.
"""
self.callDeprecated(Version("Twisted", 9, 0, 0), sip.BasicAuthorizer)
def test_deprecatedDigestAuthorizer(self):
"""
L{sip.DigestAuthorizer} is deprecated.
"""
self.callDeprecated(Version("Twisted", 9, 0, 0), sip.DigestAuthorizer)
def test_deprecatedDigestedCredentials(self):
"""
L{sip.DigestedCredentials} is deprecated.
"""
self.callDeprecated(Version("Twisted", 9, 0, 0),
sip.DigestedCredentials, '', {}, {})

View file

@ -0,0 +1,172 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
try:
import Crypto.Cipher.AES
except ImportError:
Crypto = None
from twisted.trial import unittest
from twisted.persisted import sob
from twisted.python import components
class Dummy(components.Componentized):
pass
objects = [
1,
"hello",
(1, "hello"),
[1, "hello"],
{1:"hello"},
]
class FakeModule(object):
pass
class PersistTestCase(unittest.TestCase):
def testStyles(self):
for o in objects:
p = sob.Persistent(o, '')
for style in 'source pickle'.split():
p.setStyle(style)
p.save(filename='persisttest.'+style)
o1 = sob.load('persisttest.'+style, style)
self.assertEqual(o, o1)
def testStylesBeingSet(self):
o = Dummy()
o.foo = 5
o.setComponent(sob.IPersistable, sob.Persistent(o, 'lala'))
for style in 'source pickle'.split():
sob.IPersistable(o).setStyle(style)
sob.IPersistable(o).save(filename='lala.'+style)
o1 = sob.load('lala.'+style, style)
self.assertEqual(o.foo, o1.foo)
self.assertEqual(sob.IPersistable(o1).style, style)
def testNames(self):
o = [1,2,3]
p = sob.Persistent(o, 'object')
for style in 'source pickle'.split():
p.setStyle(style)
p.save()
o1 = sob.load('object.ta'+style[0], style)
self.assertEqual(o, o1)
for tag in 'lala lolo'.split():
p.save(tag)
o1 = sob.load('object-'+tag+'.ta'+style[0], style)
self.assertEqual(o, o1)
def testEncryptedStyles(self):
for o in objects:
phrase='once I was the king of spain'
p = sob.Persistent(o, '')
for style in 'source pickle'.split():
p.setStyle(style)
p.save(filename='epersisttest.'+style, passphrase=phrase)
o1 = sob.load('epersisttest.'+style, style, phrase)
self.assertEqual(o, o1)
if Crypto is None:
testEncryptedStyles.skip = "PyCrypto required for encrypted config"
def testPython(self):
f = open("persisttest.python", 'w')
f.write('foo=[1,2,3] ')
f.close()
o = sob.loadValueFromFile('persisttest.python', 'foo')
self.assertEqual(o, [1,2,3])
def testEncryptedPython(self):
phrase='once I was the king of spain'
f = open("epersisttest.python", 'w')
f.write(
sob._encrypt(phrase, 'foo=[1,2,3]'))
f.close()
o = sob.loadValueFromFile('epersisttest.python', 'foo', phrase)
self.assertEqual(o, [1,2,3])
if Crypto is None:
testEncryptedPython.skip = "PyCrypto required for encrypted config"
def testTypeGuesser(self):
self.assertRaises(KeyError, sob.guessType, "file.blah")
self.assertEqual('python', sob.guessType("file.py"))
self.assertEqual('python', sob.guessType("file.tac"))
self.assertEqual('python', sob.guessType("file.etac"))
self.assertEqual('pickle', sob.guessType("file.tap"))
self.assertEqual('pickle', sob.guessType("file.etap"))
self.assertEqual('source', sob.guessType("file.tas"))
self.assertEqual('source', sob.guessType("file.etas"))
def testEverythingEphemeralGetattr(self):
"""
Verify that _EverythingEphermal.__getattr__ works.
"""
self.fakeMain.testMainModGetattr = 1
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, 'persisttest.ee_getattr')
f = file(filename, 'w')
f.write('import __main__\n')
f.write('if __main__.testMainModGetattr != 1: raise AssertionError\n')
f.write('app = None\n')
f.close()
sob.load(filename, 'source')
def testEverythingEphemeralSetattr(self):
"""
Verify that _EverythingEphemeral.__setattr__ won't affect __main__.
"""
self.fakeMain.testMainModSetattr = 1
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, 'persisttest.ee_setattr')
f = file(filename, 'w')
f.write('import __main__\n')
f.write('__main__.testMainModSetattr = 2\n')
f.write('app = None\n')
f.close()
sob.load(filename, 'source')
self.assertEqual(self.fakeMain.testMainModSetattr, 1)
def testEverythingEphemeralException(self):
"""
Test that an exception during load() won't cause _EE to mask __main__
"""
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, 'persisttest.ee_exception')
f = file(filename, 'w')
f.write('raise ValueError\n')
f.close()
self.assertRaises(ValueError, sob.load, filename, 'source')
self.assertEqual(type(sys.modules['__main__']), FakeModule)
def setUp(self):
"""
Replace the __main__ module with a fake one, so that it can be mutated
in tests
"""
self.realMain = sys.modules['__main__']
self.fakeMain = sys.modules['__main__'] = FakeModule()
def tearDown(self):
"""
Restore __main__ to its original value
"""
sys.modules['__main__'] = self.realMain

View file

@ -0,0 +1,498 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.protocol.socks}, an implementation of the SOCKSv4 and
SOCKSv4a protocols.
"""
import struct, socket
from twisted.trial import unittest
from twisted.test import proto_helpers
from twisted.internet import defer, address, reactor
from twisted.internet.error import DNSLookupError
from twisted.protocols import socks
class StringTCPTransport(proto_helpers.StringTransport):
stringTCPTransport_closing = False
peer = None
def getPeer(self):
return self.peer
def getHost(self):
return address.IPv4Address('TCP', '2.3.4.5', 42)
def loseConnection(self):
self.stringTCPTransport_closing = True
class FakeResolverReactor:
"""
Bare-bones reactor with deterministic behavior for the resolve method.
"""
def __init__(self, names):
"""
@type names: C{dict} containing C{str} keys and C{str} values.
@param names: A hostname to IP address mapping. The IP addresses are
stringified dotted quads.
"""
self.names = names
def resolve(self, hostname):
"""
Resolve a hostname by looking it up in the C{names} dictionary.
"""
try:
return defer.succeed(self.names[hostname])
except KeyError:
return defer.fail(
DNSLookupError("FakeResolverReactor couldn't find " + hostname))
class SOCKSv4Driver(socks.SOCKSv4):
# last SOCKSv4Outgoing instantiated
driver_outgoing = None
# last SOCKSv4IncomingFactory instantiated
driver_listen = None
def connectClass(self, host, port, klass, *args):
# fake it
proto = klass(*args)
proto.transport = StringTCPTransport()
proto.transport.peer = address.IPv4Address('TCP', host, port)
proto.connectionMade()
self.driver_outgoing = proto
return defer.succeed(proto)
def listenClass(self, port, klass, *args):
# fake it
factory = klass(*args)
self.driver_listen = factory
if port == 0:
port = 1234
return defer.succeed(('6.7.8.9', port))
class Connect(unittest.TestCase):
"""
Tests for SOCKS and SOCKSv4a connect requests using the L{SOCKSv4} protocol.
"""
def setUp(self):
self.sock = SOCKSv4Driver()
self.sock.transport = StringTCPTransport()
self.sock.connectionMade()
self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"})
def tearDown(self):
outgoing = self.sock.driver_outgoing
if outgoing is not None:
self.assert_(outgoing.transport.stringTCPTransport_closing,
"Outgoing SOCKS connections need to be closed.")
def test_simple(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 1, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 34)
+ socket.inet_aton('1.2.3.4'))
self.assert_(not self.sock.transport.stringTCPTransport_closing)
self.assert_(self.sock.driver_outgoing is not None)
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(self.sock.driver_outgoing.transport.value(),
'hello, world')
# the other way around
self.sock.driver_outgoing.dataReceived('hi there')
self.assertEqual(self.sock.transport.value(), 'hi there')
self.sock.connectionLost('fake reason')
def test_socks4aSuccessfulResolution(self):
"""
If the destination IP address has zeros for the first three octets and
non-zero for the fourth octet, the client is attempting a v4a
connection. A hostname is specified after the user ID string and the
server connects to the address that hostname resolves to.
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
"""
# send the domain name "localhost" to be resolved
clientRequest = (
struct.pack('!BBH', 4, 1, 34)
+ socket.inet_aton('0.0.0.1')
+ 'fooBAZ\0'
+ 'localhost\0')
# Deliver the bytes one by one to exercise the protocol's buffering
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
# the hostname.
for byte in clientRequest:
self.sock.dataReceived(byte)
sent = self.sock.transport.value()
self.sock.transport.clear()
# Verify that the server responded with the address which will be
# connected to.
self.assertEqual(
sent,
struct.pack('!BBH', 0, 90, 34) + socket.inet_aton('127.0.0.1'))
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
self.assertNotIdentical(self.sock.driver_outgoing, None)
# Pass some data through and verify it is forwarded to the outgoing
# connection.
self.sock.dataReceived('hello, world')
self.assertEqual(
self.sock.driver_outgoing.transport.value(), 'hello, world')
# Deliver some data from the output connection and verify it is
# passed along to the incoming side.
self.sock.driver_outgoing.dataReceived('hi there')
self.assertEqual(self.sock.transport.value(), 'hi there')
self.sock.connectionLost('fake reason')
def test_socks4aFailedResolution(self):
"""
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
response and the connection getting closed.
"""
# send the domain name "failinghost" to be resolved
clientRequest = (
struct.pack('!BBH', 4, 1, 34)
+ socket.inet_aton('0.0.0.1')
+ 'fooBAZ\0'
+ 'failinghost\0')
# Deliver the bytes one by one to exercise the protocol's buffering
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
# the hostname.
for byte in clientRequest:
self.sock.dataReceived(byte)
# Verify that the server responds with a 91 error.
sent = self.sock.transport.value()
self.assertEqual(
sent,
struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0'))
# A failed resolution causes the transport to drop the connection.
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
self.assertIdentical(self.sock.driver_outgoing, None)
def test_accessDenied(self):
self.sock.authorize = lambda code, server, port, user: 0
self.sock.dataReceived(
struct.pack('!BBH', 4, 1, 4242)
+ socket.inet_aton('10.2.3.4')
+ 'fooBAR'
+ '\0')
self.assertEqual(self.sock.transport.value(),
struct.pack('!BBH', 0, 91, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(self.sock.transport.stringTCPTransport_closing)
self.assertIdentical(self.sock.driver_outgoing, None)
def test_eofRemote(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 1, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(self.sock.driver_outgoing.transport.value(),
'hello, world')
# now close it from the server side
self.sock.driver_outgoing.transport.loseConnection()
self.sock.driver_outgoing.connectionLost('fake reason')
def test_eofLocal(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 1, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(self.sock.driver_outgoing.transport.value(),
'hello, world')
# now close it from the client side
self.sock.connectionLost('fake reason')
class Bind(unittest.TestCase):
"""
Tests for SOCKS and SOCKSv4a bind requests using the L{SOCKSv4} protocol.
"""
def setUp(self):
self.sock = SOCKSv4Driver()
self.sock.transport = StringTCPTransport()
self.sock.connectionMade()
self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"})
## def tearDown(self):
## # TODO ensure the listen port is closed
## listen = self.sock.driver_listen
## if listen is not None:
## self.assert_(incoming.transport.stringTCPTransport_closing,
## "Incoming SOCKS connections need to be closed.")
def test_simple(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 1234)
+ socket.inet_aton('6.7.8.9'))
self.assert_(not self.sock.transport.stringTCPTransport_closing)
self.assert_(self.sock.driver_listen is not None)
# connect
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
self.assertNotIdentical(incoming, None)
incoming.transport = StringTCPTransport()
incoming.connectionMade()
# now we should have the second reply packet
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(not self.sock.transport.stringTCPTransport_closing)
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(incoming.transport.value(),
'hello, world')
# the other way around
incoming.dataReceived('hi there')
self.assertEqual(self.sock.transport.value(), 'hi there')
self.sock.connectionLost('fake reason')
def test_socks4a(self):
"""
If the destination IP address has zeros for the first three octets and
non-zero for the fourth octet, the client is attempting a v4a
connection. A hostname is specified after the user ID string and the
server connects to the address that hostname resolves to.
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
"""
# send the domain name "localhost" to be resolved
clientRequest = (
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('0.0.0.1')
+ 'fooBAZ\0'
+ 'localhost\0')
# Deliver the bytes one by one to exercise the protocol's buffering
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
# the hostname.
for byte in clientRequest:
self.sock.dataReceived(byte)
sent = self.sock.transport.value()
self.sock.transport.clear()
# Verify that the server responded with the address which will be
# connected to.
self.assertEqual(
sent,
struct.pack('!BBH', 0, 90, 1234) + socket.inet_aton('6.7.8.9'))
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
self.assertNotIdentical(self.sock.driver_listen, None)
# connect
incoming = self.sock.driver_listen.buildProtocol(('127.0.0.1', 5345))
self.assertNotIdentical(incoming, None)
incoming.transport = StringTCPTransport()
incoming.connectionMade()
# now we should have the second reply packet
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 0)
+ socket.inet_aton('0.0.0.0'))
self.assertNotIdentical(
self.sock.transport.stringTCPTransport_closing, None)
# Deliver some data from the output connection and verify it is
# passed along to the incoming side.
self.sock.dataReceived('hi there')
self.assertEqual(incoming.transport.value(), 'hi there')
# the other way around
incoming.dataReceived('hi there')
self.assertEqual(self.sock.transport.value(), 'hi there')
self.sock.connectionLost('fake reason')
def test_socks4aFailedResolution(self):
"""
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
response and the connection getting closed.
"""
# send the domain name "failinghost" to be resolved
clientRequest = (
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('0.0.0.1')
+ 'fooBAZ\0'
+ 'failinghost\0')
# Deliver the bytes one by one to exercise the protocol's buffering
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
# the hostname.
for byte in clientRequest:
self.sock.dataReceived(byte)
# Verify that the server responds with a 91 error.
sent = self.sock.transport.value()
self.assertEqual(
sent,
struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0'))
# A failed resolution causes the transport to drop the connection.
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
self.assertIdentical(self.sock.driver_outgoing, None)
def test_accessDenied(self):
self.sock.authorize = lambda code, server, port, user: 0
self.sock.dataReceived(
struct.pack('!BBH', 4, 2, 4242)
+ socket.inet_aton('10.2.3.4')
+ 'fooBAR'
+ '\0')
self.assertEqual(self.sock.transport.value(),
struct.pack('!BBH', 0, 91, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(self.sock.transport.stringTCPTransport_closing)
self.assertIdentical(self.sock.driver_listen, None)
def test_eofRemote(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
# connect
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
self.assertNotIdentical(incoming, None)
incoming.transport = StringTCPTransport()
incoming.connectionMade()
# now we should have the second reply packet
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(not self.sock.transport.stringTCPTransport_closing)
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(incoming.transport.value(),
'hello, world')
# now close it from the server side
incoming.transport.loseConnection()
incoming.connectionLost('fake reason')
def test_eofLocal(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
# connect
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
self.assertNotIdentical(incoming, None)
incoming.transport = StringTCPTransport()
incoming.connectionMade()
# now we should have the second reply packet
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 90, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(not self.sock.transport.stringTCPTransport_closing)
# pass some data through
self.sock.dataReceived('hello, world')
self.assertEqual(incoming.transport.value(),
'hello, world')
# now close it from the client side
self.sock.connectionLost('fake reason')
def test_badSource(self):
self.sock.dataReceived(
struct.pack('!BBH', 4, 2, 34)
+ socket.inet_aton('1.2.3.4')
+ 'fooBAR'
+ '\0')
sent = self.sock.transport.value()
self.sock.transport.clear()
# connect from WRONG address
incoming = self.sock.driver_listen.buildProtocol(('1.6.6.6', 666))
self.assertIdentical(incoming, None)
# Now we should have the second reply packet and it should
# be a failure. The connection should be closing.
sent = self.sock.transport.value()
self.sock.transport.clear()
self.assertEqual(sent,
struct.pack('!BBH', 0, 91, 0)
+ socket.inet_aton('0.0.0.0'))
self.assert_(self.sock.transport.stringTCPTransport_closing)

View file

@ -0,0 +1,727 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted SSL support.
"""
from __future__ import division, absolute_import
from twisted.python.filepath import FilePath
from twisted.trial import unittest
from twisted.internet import protocol, reactor, interfaces, defer
from twisted.internet.error import ConnectionDone
from twisted.protocols import basic
from twisted.python.runtime import platform
from twisted.test.test_tcp import ProperlyCloseFilesMixin
import os, errno
try:
from OpenSSL import SSL, crypto
from twisted.internet import ssl
from twisted.test.ssl_helpers import ClientTLSContext, certPath
except ImportError:
def _noSSL():
# ugh, make pyflakes happy.
global SSL
global ssl
SSL = ssl = None
_noSSL()
try:
from twisted.protocols import tls as newTLS
except ImportError:
# Assuming SSL exists, we're using old version in reactor (i.e. non-protocol)
newTLS = None
class UnintelligentProtocol(basic.LineReceiver):
"""
@ivar deferred: a deferred that will fire at connection lost.
@type deferred: L{defer.Deferred}
@cvar pretext: text sent before TLS is set up.
@type pretext: C{bytes}
@cvar posttext: text sent after TLS is set up.
@type posttext: C{bytes}
"""
pretext = [
b"first line",
b"last thing before tls starts",
b"STARTTLS"]
posttext = [
b"first thing after tls started",
b"last thing ever"]
def __init__(self):
self.deferred = defer.Deferred()
def connectionMade(self):
for l in self.pretext:
self.sendLine(l)
def lineReceived(self, line):
if line == b"READY":
self.transport.startTLS(ClientTLSContext(), self.factory.client)
for l in self.posttext:
self.sendLine(l)
self.transport.loseConnection()
def connectionLost(self, reason):
self.deferred.callback(None)
class LineCollector(basic.LineReceiver):
"""
@ivar deferred: a deferred that will fire at connection lost.
@type deferred: L{defer.Deferred}
@ivar doTLS: whether the protocol is initiate TLS or not.
@type doTLS: C{bool}
@ivar fillBuffer: if set to True, it will send lots of data once
C{STARTTLS} is received.
@type fillBuffer: C{bool}
"""
def __init__(self, doTLS, fillBuffer=False):
self.doTLS = doTLS
self.fillBuffer = fillBuffer
self.deferred = defer.Deferred()
def connectionMade(self):
self.factory.rawdata = b''
self.factory.lines = []
def lineReceived(self, line):
self.factory.lines.append(line)
if line == b'STARTTLS':
if self.fillBuffer:
for x in range(500):
self.sendLine(b'X' * 1000)
self.sendLine(b'READY')
if self.doTLS:
ctx = ServerTLSContext(
privateKeyFileName=certPath,
certificateFileName=certPath,
)
self.transport.startTLS(ctx, self.factory.server)
else:
self.setRawMode()
def rawDataReceived(self, data):
self.factory.rawdata += data
self.transport.loseConnection()
def connectionLost(self, reason):
self.deferred.callback(None)
class SingleLineServerProtocol(protocol.Protocol):
"""
A protocol that sends a single line of data at C{connectionMade}.
"""
def connectionMade(self):
self.transport.write(b"+OK <some crap>\r\n")
self.transport.getPeerCertificate()
class RecordingClientProtocol(protocol.Protocol):
"""
@ivar deferred: a deferred that will fire with first received content.
@type deferred: L{defer.Deferred}
"""
def __init__(self):
self.deferred = defer.Deferred()
def connectionMade(self):
self.transport.getPeerCertificate()
def dataReceived(self, data):
self.deferred.callback(data)
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
"""
A protocol that disconnect immediately on connection. It fires the
C{connectionDisconnected} deferred of its factory on connetion lost.
"""
def connectionMade(self):
self.transport.loseConnection()
def connectionLost(self, reason):
self.factory.connectionDisconnected.callback(None)
def generateCertificateObjects(organization, organizationalUnit):
"""
Create a certificate for given C{organization} and C{organizationalUnit}.
@return: a tuple of (key, request, certificate) objects.
"""
pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, 512)
req = crypto.X509Req()
subject = req.get_subject()
subject.O = organization
subject.OU = organizationalUnit
req.set_pubkey(pkey)
req.sign(pkey, "md5")
# Here comes the actual certificate
cert = crypto.X509()
cert.set_serial_number(1)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
cert.set_issuer(req.get_subject())
cert.set_subject(req.get_subject())
cert.set_pubkey(req.get_pubkey())
cert.sign(pkey, "md5")
return pkey, req, cert
def generateCertificateFiles(basename, organization, organizationalUnit):
"""
Create certificate files key, req and cert prefixed by C{basename} for
given C{organization} and C{organizationalUnit}.
"""
pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
for ext, obj, dumpFunc in [
('key', pkey, crypto.dump_privatekey),
('req', req, crypto.dump_certificate_request),
('cert', cert, crypto.dump_certificate)]:
fName = os.extsep.join((basename, ext)).encode("utf-8")
FilePath(fName).setContent(dumpFunc(crypto.FILETYPE_PEM, obj))
class ContextGeneratingMixin:
"""
Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
and server.
@ivar clientBase: prefix of client certificate files.
@type clientBase: C{str}
@ivar serverBase: prefix of server certificate files.
@type serverBase: C{str}
@ivar clientCtxFactory: a generated context factory to be used in
C{reactor.connectSSL}.
@type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
@ivar serverCtxFactory: a generated context factory to be used in
C{reactor.listenSSL}.
@type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
"""
def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
base = self.mktemp()
generateCertificateFiles(base, org, orgUnit)
serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
os.extsep.join((base, 'key')),
os.extsep.join((base, 'cert')),
*args, **kwArgs)
return base, serverCtxFactory
def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
serverKwArgs):
self.clientBase, self.clientCtxFactory = self.makeContextFactory(
*clientArgs, **clientKwArgs)
self.serverBase, self.serverCtxFactory = self.makeContextFactory(
*serverArgs, **serverKwArgs)
if SSL is not None:
class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
"""
A context factory with a default method set to L{SSL.TLSv1_METHOD}.
"""
isClient = False
def __init__(self, *args, **kw):
kw['sslmethod'] = SSL.TLSv1_METHOD
ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
class StolenTCPTestCase(ProperlyCloseFilesMixin, unittest.TestCase):
"""
For SSL transports, test many of the same things which are tested for
TCP transports.
"""
def createServer(self, address, portNumber, factory):
"""
Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
"""
cert = ssl.PrivateCertificate.loadPEM(FilePath(certPath).getContent())
contextFactory = cert.options()
return reactor.listenSSL(
portNumber, factory, contextFactory, interface=address)
def connectClient(self, address, portNumber, clientCreator):
"""
Create an SSL client using L{IReactorSSL.connectSSL}.
"""
contextFactory = ssl.CertificateOptions()
return clientCreator.connectSSL(address, portNumber, contextFactory)
def getHandleExceptionType(self):
"""
Return L{SSL.Error} as the expected error type which will be raised by
a write to the L{OpenSSL.SSL.Connection} object after it has been
closed.
"""
return SSL.Error
def getHandleErrorCode(self):
"""
Return the argument L{SSL.Error} will be constructed with for this
case. This is basically just a random OpenSSL implementation detail.
It would be better if this test worked in a way which did not require
this.
"""
# Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for
# SSL.Connection.write for some reason. The twisted.protocols.tls
# implementation of IReactorSSL doesn't suffer from this imprecation,
# though, since it is isolated from the Windows I/O layer (I suppose?).
# If test_properlyCloseFiles waited for the SSL handshake to complete
# and performed an orderly shutdown, then this would probably be a
# little less weird: writing to a shutdown SSL connection has a more
# well-defined failure mode (or at least it should).
# So figure out if twisted.protocols.tls is in use. If it can be
# imported, it should be.
try:
import twisted.protocols.tls
except ImportError:
# It isn't available, so we expect WSAENOTSOCK if we're on Windows.
if platform.getType() == 'win32':
return errno.WSAENOTSOCK
# Otherwise, we expect an error about how we tried to write to a
# shutdown connection. This is terribly implementation-specific.
return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
class TLSTestCase(unittest.TestCase):
"""
Tests for startTLS support.
@ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
@type fillBuffer: C{bool}
"""
fillBuffer = False
clientProto = None
serverProto = None
def tearDown(self):
if self.clientProto.transport is not None:
self.clientProto.transport.loseConnection()
if self.serverProto.transport is not None:
self.serverProto.transport.loseConnection()
def _runTest(self, clientProto, serverProto, clientIsServer=False):
"""
Helper method to run TLS tests.
@param clientProto: protocol instance attached to the client
connection.
@param serverProto: protocol instance attached to the server
connection.
@param clientIsServer: flag indicated if client should initiate
startTLS instead of server.
@return: a L{defer.Deferred} that will fire when both connections are
lost.
"""
self.clientProto = clientProto
cf = self.clientFactory = protocol.ClientFactory()
cf.protocol = lambda: clientProto
if clientIsServer:
cf.server = False
else:
cf.client = True
self.serverProto = serverProto
sf = self.serverFactory = protocol.ServerFactory()
sf.protocol = lambda: serverProto
if clientIsServer:
sf.client = False
else:
sf.server = True
port = reactor.listenTCP(0, sf, interface="127.0.0.1")
self.addCleanup(port.stopListening)
reactor.connectTCP('127.0.0.1', port.getHost().port, cf)
return defer.gatherResults([clientProto.deferred, serverProto.deferred])
def test_TLS(self):
"""
Test for server and client startTLS: client should received data both
before and after the startTLS.
"""
def check(ignore):
self.assertEqual(
self.serverFactory.lines,
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
)
d = self._runTest(UnintelligentProtocol(),
LineCollector(True, self.fillBuffer))
return d.addCallback(check)
def test_unTLS(self):
"""
Test for server startTLS not followed by a startTLS in client: the data
received after server startTLS should be received as raw.
"""
def check(ignored):
self.assertEqual(
self.serverFactory.lines,
UnintelligentProtocol.pretext
)
self.failUnless(self.serverFactory.rawdata,
"No encrypted bytes received")
d = self._runTest(UnintelligentProtocol(),
LineCollector(False, self.fillBuffer))
return d.addCallback(check)
def test_backwardsTLS(self):
"""
Test startTLS first initiated by client.
"""
def check(ignored):
self.assertEqual(
self.clientFactory.lines,
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
)
d = self._runTest(LineCollector(True, self.fillBuffer),
UnintelligentProtocol(), True)
return d.addCallback(check)
class SpammyTLSTestCase(TLSTestCase):
"""
Test TLS features with bytes sitting in the out buffer.
"""
fillBuffer = True
class BufferingTestCase(unittest.TestCase):
serverProto = None
clientProto = None
def tearDown(self):
if self.serverProto.transport is not None:
self.serverProto.transport.loseConnection()
if self.clientProto.transport is not None:
self.clientProto.transport.loseConnection()
def test_openSSLBuffering(self):
serverProto = self.serverProto = SingleLineServerProtocol()
clientProto = self.clientProto = RecordingClientProtocol()
server = protocol.ServerFactory()
client = self.client = protocol.ClientFactory()
server.protocol = lambda: serverProto
client.protocol = lambda: clientProto
sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
cCTX = ssl.ClientContextFactory()
port = reactor.listenSSL(0, server, sCTX, interface='127.0.0.1')
self.addCleanup(port.stopListening)
reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX)
return clientProto.deferred.addCallback(
self.assertEqual, b"+OK <some crap>\r\n")
class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin):
"""
SSL connection closing tests.
"""
def testImmediateDisconnect(self):
org = "twisted.test.test_ssl"
self.setupServerAndClient(
(org, org + ", client"), {},
(org, org + ", server"), {})
# Set up a server, connect to it with a client, which should work since our verifiers
# allow anything, then disconnect.
serverProtocolFactory = protocol.ServerFactory()
serverProtocolFactory.protocol = protocol.Protocol
self.serverPort = serverPort = reactor.listenSSL(0,
serverProtocolFactory, self.serverCtxFactory)
clientProtocolFactory = protocol.ClientFactory()
clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
clientProtocolFactory.connectionDisconnected = defer.Deferred()
clientConnector = reactor.connectSSL('127.0.0.1',
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
return clientProtocolFactory.connectionDisconnected.addCallback(
lambda ignoredResult: self.serverPort.stopListening())
def test_bothSidesLoseConnection(self):
"""
Both sides of SSL connection close connection; the connections should
close cleanly, and only after the underlying TCP connection has
disconnected.
"""
class CloseAfterHandshake(protocol.Protocol):
gotData = False
def __init__(self):
self.done = defer.Deferred()
def connectionMade(self):
self.transport.write(b"a")
def dataReceived(self, data):
# If we got data, handshake is over:
self.gotData = True
self.transport.loseConnection()
def connectionLost(self, reason):
if not self.gotData:
reason = RuntimeError("We never received the data!")
self.done.errback(reason)
del self.done
org = "twisted.test.test_ssl"
self.setupServerAndClient(
(org, org + ", client"), {},
(org, org + ", server"), {})
serverProtocol = CloseAfterHandshake()
serverProtocolFactory = protocol.ServerFactory()
serverProtocolFactory.protocol = lambda: serverProtocol
serverPort = reactor.listenSSL(0,
serverProtocolFactory, self.serverCtxFactory)
self.addCleanup(serverPort.stopListening)
clientProtocol = CloseAfterHandshake()
clientProtocolFactory = protocol.ClientFactory()
clientProtocolFactory.protocol = lambda: clientProtocol
clientConnector = reactor.connectSSL('127.0.0.1',
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
def checkResult(failure):
failure.trap(ConnectionDone)
return defer.gatherResults(
[clientProtocol.done.addErrback(checkResult),
serverProtocol.done.addErrback(checkResult)])
if newTLS is None:
test_bothSidesLoseConnection.skip = "Old SSL code doesn't always close cleanly."
def testFailedVerify(self):
org = "twisted.test.test_ssl"
self.setupServerAndClient(
(org, org + ", client"), {},
(org, org + ", server"), {})
def verify(*a):
return False
self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
serverConnLost = defer.Deferred()
serverProtocol = protocol.Protocol()
serverProtocol.connectionLost = serverConnLost.callback
serverProtocolFactory = protocol.ServerFactory()
serverProtocolFactory.protocol = lambda: serverProtocol
self.serverPort = serverPort = reactor.listenSSL(0,
serverProtocolFactory, self.serverCtxFactory)
clientConnLost = defer.Deferred()
clientProtocol = protocol.Protocol()
clientProtocol.connectionLost = clientConnLost.callback
clientProtocolFactory = protocol.ClientFactory()
clientProtocolFactory.protocol = lambda: clientProtocol
clientConnector = reactor.connectSSL('127.0.0.1',
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
return dl.addCallback(self._cbLostConns)
def _cbLostConns(self, results):
(sSuccess, sResult), (cSuccess, cResult) = results
self.failIf(sSuccess)
self.failIf(cSuccess)
acceptableErrors = [SSL.Error]
# Rather than getting a verification failure on Windows, we are getting
# a connection failure. Without something like sslverify proxying
# in-between we can't fix up the platform's errors, so let's just
# specifically say it is only OK in this one case to keep the tests
# passing. Normally we'd like to be as strict as possible here, so
# we're not going to allow this to report errors incorrectly on any
# other platforms.
if platform.isWindows():
from twisted.internet.error import ConnectionLost
acceptableErrors.append(ConnectionLost)
sResult.trap(*acceptableErrors)
cResult.trap(*acceptableErrors)
return self.serverPort.stopListening()
class FakeContext:
"""
L{OpenSSL.SSL.Context} double which can more easily be inspected.
"""
def __init__(self, method):
self._method = method
self._options = 0
def set_options(self, options):
self._options |= options
def use_certificate_file(self, fileName):
pass
def use_privatekey_file(self, fileName):
pass
class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
"""
Tests for L{ssl.DefaultOpenSSLContextFactory}.
"""
def setUp(self):
# pyOpenSSL Context objects aren't introspectable enough. Pass in
# an alternate context factory so we can inspect what is done to it.
self.contextFactory = ssl.DefaultOpenSSLContextFactory(
certPath, certPath, _contextFactory=FakeContext)
self.context = self.contextFactory.getContext()
def test_method(self):
"""
L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
which can use SSLv3 or TLSv1 but not SSLv2.
"""
# SSLv23_METHOD allows SSLv2, SSLv3, or TLSv1
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
# And OP_NO_SSLv2 disables the SSLv2 support.
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
# Make sure SSLv3 and TLSv1 aren't disabled though.
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
def test_missingCertificateFile(self):
"""
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
filename which does not identify an existing file results in the
initializer raising L{OpenSSL.SSL.Error}.
"""
self.assertRaises(
SSL.Error,
ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp())
def test_missingPrivateKeyFile(self):
"""
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
filename which does not identify an existing file results in the
initializer raising L{OpenSSL.SSL.Error}.
"""
self.assertRaises(
SSL.Error,
ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath)
class ClientContextFactoryTests(unittest.TestCase):
"""
Tests for L{ssl.ClientContextFactory}.
"""
def setUp(self):
self.contextFactory = ssl.ClientContextFactory()
self.contextFactory._contextFactory = FakeContext
self.context = self.contextFactory.getContext()
def test_method(self):
"""
L{ssl.ClientContextFactory.getContext} returns a context which can use
SSLv3 or TLSv1 but not SSLv2.
"""
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
if interfaces.IReactorSSL(reactor, None) is None:
for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase,
BufferingTestCase, ConnectionLostTestCase,
DefaultOpenSSLContextFactoryTests,
ClientContextFactoryTests]:
tCase.skip = "Reactor does not support SSL, cannot run SSL tests"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,81 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.stateful
"""
from twisted.trial.unittest import TestCase
from twisted.protocols.test import test_basic
from twisted.protocols.stateful import StatefulProtocol
from struct import pack, unpack, calcsize
class MyInt32StringReceiver(StatefulProtocol):
"""
A stateful Int32StringReceiver.
"""
MAX_LENGTH = 99999
structFormat = "!I"
prefixLength = calcsize(structFormat)
def getInitialState(self):
return self._getHeader, 4
def lengthLimitExceeded(self, length):
self.transport.loseConnection()
def _getHeader(self, msg):
length, = unpack("!i", msg)
if length > self.MAX_LENGTH:
self.lengthLimitExceeded(length)
return
return self._getString, length
def _getString(self, msg):
self.stringReceived(msg)
return self._getHeader, 4
def stringReceived(self, msg):
"""
Override this.
"""
raise NotImplementedError
def sendString(self, data):
"""
Send an int32-prefixed string to the other end of the connection.
"""
self.transport.write(pack(self.structFormat, len(data)) + data)
class TestInt32(MyInt32StringReceiver):
def connectionMade(self):
self.received = []
def stringReceived(self, s):
self.received.append(s)
MAX_LENGTH = 50
closed = 0
def connectionLost(self, reason):
self.closed = 1
class Int32TestCase(TestCase, test_basic.IntNTestCaseMixin):
protocol = TestInt32
strings = ["a", "b" * 16]
illegalStrings = ["\x10\x00\x00\x00aaaaaa"]
partialStrings = ["\x00\x00\x00", "hello there", ""]
def test_bigReceive(self):
r = self.getProtocol()
big = ""
for s in self.strings * 4:
big += pack("!i", len(s)) + s
r.dataReceived(big)
self.assertEqual(r.received, self.strings * 4)

View file

@ -0,0 +1,371 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.stdio}.
"""
import os, sys, itertools
from twisted.trial import unittest
from twisted.python import filepath, log
from twisted.python.runtime import platform
from twisted.internet import error, defer, protocol, stdio, reactor
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
# A short string which is intended to appear here and nowhere else,
# particularly not in any random garbage output CPython unavoidable
# generates (such as in warning text and so forth). This is searched
# for in the output from stdio_test_lastwrite.py and if it is found at
# the end, the functionality works.
UNIQUE_LAST_WRITE_STRING = 'xyz123abc Twisted is great!'
skipWindowsNopywin32 = None
if platform.isWindows():
try:
import win32process
except ImportError:
skipWindowsNopywin32 = ("On windows, spawnProcess is not available "
"in the absence of win32process.")
class StandardIOTestProcessProtocol(protocol.ProcessProtocol):
"""
Test helper for collecting output from a child process and notifying
something when it exits.
@ivar onConnection: A L{defer.Deferred} which will be called back with
C{None} when the connection to the child process is established.
@ivar onCompletion: A L{defer.Deferred} which will be errbacked with the
failure associated with the child process exiting when it exits.
@ivar onDataReceived: A L{defer.Deferred} which will be called back with
this instance whenever C{childDataReceived} is called, or C{None} to
suppress these callbacks.
@ivar data: A C{dict} mapping file descriptors to strings containing all
bytes received from the child process on each file descriptor.
"""
onDataReceived = None
def __init__(self):
self.onConnection = defer.Deferred()
self.onCompletion = defer.Deferred()
self.data = {}
def connectionMade(self):
self.onConnection.callback(None)
def childDataReceived(self, name, bytes):
"""
Record all bytes received from the child process in the C{data}
dictionary. Fire C{onDataReceived} if it is not C{None}.
"""
self.data[name] = self.data.get(name, '') + bytes
if self.onDataReceived is not None:
d, self.onDataReceived = self.onDataReceived, None
d.callback(self)
def processEnded(self, reason):
self.onCompletion.callback(reason)
class StandardInputOutputTestCase(unittest.TestCase):
skip = skipWindowsNopywin32
def _spawnProcess(self, proto, sibling, *args, **kw):
"""
Launch a child Python process and communicate with it using the
given ProcessProtocol.
@param proto: A L{ProcessProtocol} instance which will be connected
to the child process.
@param sibling: The basename of a file containing the Python program
to run in the child process.
@param *args: strings which will be passed to the child process on
the command line as C{argv[2:]}.
@param **kw: additional arguments to pass to L{reactor.spawnProcess}.
@return: The L{IProcessTransport} provider for the spawned process.
"""
import twisted
subenv = dict(os.environ)
subenv['PYTHONPATH'] = os.pathsep.join(
[os.path.abspath(
os.path.dirname(os.path.dirname(twisted.__file__))),
subenv.get('PYTHONPATH', '')
])
args = [sys.executable,
filepath.FilePath(__file__).sibling(sibling).path,
reactor.__class__.__module__] + list(args)
return reactor.spawnProcess(
proto,
sys.executable,
args,
env=subenv,
**kw)
def _requireFailure(self, d, callback):
def cb(result):
self.fail("Process terminated with non-Failure: %r" % (result,))
def eb(err):
return callback(err)
return d.addCallbacks(cb, eb)
def test_loseConnection(self):
"""
Verify that a protocol connected to L{StandardIO} can disconnect
itself using C{transport.loseConnection}.
"""
errorLogFile = self.mktemp()
log.msg("Child process logging to " + errorLogFile)
p = StandardIOTestProcessProtocol()
d = p.onCompletion
self._spawnProcess(p, 'stdio_test_loseconn.py', errorLogFile)
def processEnded(reason):
# Copy the child's log to ours so it's more visible.
for line in file(errorLogFile):
log.msg("Child logged: " + line.rstrip())
self.failIfIn(1, p.data)
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def test_readConnectionLost(self):
"""
When stdin is closed and the protocol connected to it implements
L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method
is called.
"""
errorLogFile = self.mktemp()
log.msg("Child process logging to " + errorLogFile)
p = StandardIOTestProcessProtocol()
p.onDataReceived = defer.Deferred()
def cbBytes(ignored):
d = p.onCompletion
p.transport.closeStdin()
return d
p.onDataReceived.addCallback(cbBytes)
def processEnded(reason):
reason.trap(error.ProcessDone)
d = self._requireFailure(p.onDataReceived, processEnded)
self._spawnProcess(
p, 'stdio_test_halfclose.py', errorLogFile)
return d
def test_lastWriteReceived(self):
"""
Verify that a write made directly to stdout using L{os.write}
after StandardIO has finished is reliably received by the
process reading that stdout.
"""
p = StandardIOTestProcessProtocol()
# Note: the OS X bug which prompted the addition of this test
# is an apparent race condition involving non-blocking PTYs.
# Delaying the parent process significantly increases the
# likelihood of the race going the wrong way. If you need to
# fiddle with this code at all, uncommenting the next line
# will likely make your life much easier. It is commented out
# because it makes the test quite slow.
# p.onConnection.addCallback(lambda ign: __import__('time').sleep(5))
try:
self._spawnProcess(
p, 'stdio_test_lastwrite.py', UNIQUE_LAST_WRITE_STRING,
usePTY=True)
except ValueError, e:
# Some platforms don't work with usePTY=True
raise unittest.SkipTest(str(e))
def processEnded(reason):
"""
Asserts that the parent received the bytes written by the child
immediately after the child starts.
"""
self.assertTrue(
p.data[1].endswith(UNIQUE_LAST_WRITE_STRING),
"Received %r from child, did not find expected bytes." % (
p.data,))
reason.trap(error.ProcessDone)
return self._requireFailure(p.onCompletion, processEnded)
def test_hostAndPeer(self):
"""
Verify that the transport of a protocol connected to L{StandardIO}
has C{getHost} and C{getPeer} methods.
"""
p = StandardIOTestProcessProtocol()
d = p.onCompletion
self._spawnProcess(p, 'stdio_test_hostpeer.py')
def processEnded(reason):
host, peer = p.data[1].splitlines()
self.failUnless(host)
self.failUnless(peer)
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def test_write(self):
"""
Verify that the C{write} method of the transport of a protocol
connected to L{StandardIO} sends bytes to standard out.
"""
p = StandardIOTestProcessProtocol()
d = p.onCompletion
self._spawnProcess(p, 'stdio_test_write.py')
def processEnded(reason):
self.assertEqual(p.data[1], 'ok!')
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def test_writeSequence(self):
"""
Verify that the C{writeSequence} method of the transport of a
protocol connected to L{StandardIO} sends bytes to standard out.
"""
p = StandardIOTestProcessProtocol()
d = p.onCompletion
self._spawnProcess(p, 'stdio_test_writeseq.py')
def processEnded(reason):
self.assertEqual(p.data[1], 'ok!')
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def _junkPath(self):
junkPath = self.mktemp()
junkFile = file(junkPath, 'w')
for i in xrange(1024):
junkFile.write(str(i) + '\n')
junkFile.close()
return junkPath
def test_producer(self):
"""
Verify that the transport of a protocol connected to L{StandardIO}
is a working L{IProducer} provider.
"""
p = StandardIOTestProcessProtocol()
d = p.onCompletion
written = []
toWrite = range(100)
def connectionMade(ign):
if toWrite:
written.append(str(toWrite.pop()) + "\n")
proc.write(written[-1])
reactor.callLater(0.01, connectionMade, None)
proc = self._spawnProcess(p, 'stdio_test_producer.py')
p.onConnection.addCallback(connectionMade)
def processEnded(reason):
self.assertEqual(p.data[1], ''.join(written))
self.failIf(toWrite, "Connection lost with %d writes left to go." % (len(toWrite),))
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def test_consumer(self):
"""
Verify that the transport of a protocol connected to L{StandardIO}
is a working L{IConsumer} provider.
"""
p = StandardIOTestProcessProtocol()
d = p.onCompletion
junkPath = self._junkPath()
self._spawnProcess(p, 'stdio_test_consumer.py', junkPath)
def processEnded(reason):
self.assertEqual(p.data[1], file(junkPath).read())
reason.trap(error.ProcessDone)
return self._requireFailure(d, processEnded)
def test_normalFileStandardOut(self):
"""
If L{StandardIO} is created with a file descriptor which refers to a
normal file (ie, a file from the filesystem), L{StandardIO.write}
writes bytes to that file. In particular, it does not immediately
consider the file closed or call its protocol's C{connectionLost}
method.
"""
onConnLost = defer.Deferred()
proto = ConnectionLostNotifyingProtocol(onConnLost)
path = filepath.FilePath(self.mktemp())
self.normal = normal = path.open('w')
self.addCleanup(normal.close)
kwargs = dict(stdout=normal.fileno())
if not platform.isWindows():
# Make a fake stdin so that StandardIO doesn't mess with the *real*
# stdin.
r, w = os.pipe()
self.addCleanup(os.close, r)
self.addCleanup(os.close, w)
kwargs['stdin'] = r
connection = stdio.StandardIO(proto, **kwargs)
# The reactor needs to spin a bit before it might have incorrectly
# decided stdout is closed. Use this counter to keep track of how
# much we've let it spin. If it closes before we expected, this
# counter will have a value that's too small and we'll know.
howMany = 5
count = itertools.count()
def spin():
for value in count:
if value == howMany:
connection.loseConnection()
return
connection.write(str(value))
break
reactor.callLater(0, spin)
reactor.callLater(0, spin)
# Once the connection is lost, make sure the counter is at the
# appropriate value.
def cbLost(reason):
self.assertEqual(count.next(), howMany + 1)
self.assertEqual(
path.getContent(),
''.join(map(str, range(howMany))))
onConnLost.addCallback(cbLost)
return onConnLost
if platform.isWindows():
test_normalFileStandardOut.skip = (
"StandardIO does not accept stdout as an argument to Windows. "
"Testing redirection to a file is therefore harder.")

Some files were not shown because too many files have changed in this diff Show more