add Linux_i686
This commit is contained in:
parent
75f9a2fcbc
commit
95cd9b11f2
1644 changed files with 564260 additions and 0 deletions
|
|
@ -0,0 +1,10 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
Twisted Test: Unit Tests for Twisted.
|
||||
|
||||
"""
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# This makes sure Twisted-using child processes used in the test suite import
|
||||
# the correct version of Twisted (ie, the version of Twisted under test).
|
||||
|
||||
# This is a copy of the bin/_preamble.py script because it's not clear how to
|
||||
# use the functionality for both things without having a copy.
|
||||
|
||||
import sys, os
|
||||
|
||||
path = os.path.abspath(sys.argv[0])
|
||||
while os.path.dirname(path) != path:
|
||||
if os.path.exists(os.path.join(path, 'twisted', '__init__.py')):
|
||||
sys.path.insert(0, path)
|
||||
break
|
||||
path = os.path.dirname(path)
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.python import components
|
||||
from zope.interface import implements, Interface
|
||||
|
||||
def foo():
|
||||
return 2
|
||||
|
||||
class X:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def do(self):
|
||||
#print 'X',self.x,'doing!'
|
||||
pass
|
||||
|
||||
|
||||
class XComponent(components.Componentized):
|
||||
pass
|
||||
|
||||
class IX(Interface):
|
||||
pass
|
||||
|
||||
class XA(components.Adapter):
|
||||
implements(IX)
|
||||
|
||||
def method(self):
|
||||
# Kick start :(
|
||||
pass
|
||||
|
||||
components.registerAdapter(XA, X, IX)
|
||||
407
Linux_i686/lib/python2.7/site-packages/twisted/test/iosim.py
Normal file
407
Linux_i686/lib/python2.7/site-packages/twisted/test/iosim.py
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
# -*- test-case-name: twisted.test.test_amp.TLSTest,twisted.test.test_iosim -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Utilities and helpers for simulating a network
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
try:
|
||||
from OpenSSL.SSL import Error as NativeOpenSSLError
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from zope.interface import implementer, directlyProvides
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet import error
|
||||
from twisted.internet import interfaces
|
||||
|
||||
class TLSNegotiation:
|
||||
def __init__(self, obj, connectState):
|
||||
self.obj = obj
|
||||
self.connectState = connectState
|
||||
self.sent = False
|
||||
self.readyToSend = connectState
|
||||
|
||||
def __repr__(self):
|
||||
return 'TLSNegotiation(%r)' % (self.obj,)
|
||||
|
||||
def pretendToVerify(self, other, tpt):
|
||||
# Set the transport problems list here? disconnections?
|
||||
# hmmmmm... need some negative path tests.
|
||||
|
||||
if not self.obj.iosimVerify(other.obj):
|
||||
tpt.disconnectReason = NativeOpenSSLError()
|
||||
tpt.loseConnection()
|
||||
|
||||
|
||||
|
||||
implementer(interfaces.IAddress)
|
||||
class FakeAddress(object):
|
||||
"""
|
||||
The default address type for the host and peer of L{FakeTransport}
|
||||
connections.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport,
|
||||
interfaces.ITLSTransport)
|
||||
class FakeTransport:
|
||||
"""
|
||||
A wrapper around a file-like object to make it behave as a Transport.
|
||||
|
||||
This doesn't actually stream the file to the attached protocol,
|
||||
and is thus useful mainly as a utility for debugging protocols.
|
||||
"""
|
||||
|
||||
_nextserial = staticmethod(lambda counter=itertools.count(): next(counter))
|
||||
closed = 0
|
||||
disconnecting = 0
|
||||
disconnected = 0
|
||||
disconnectReason = error.ConnectionDone("Connection done")
|
||||
producer = None
|
||||
streamingProducer = 0
|
||||
tls = None
|
||||
|
||||
def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
|
||||
"""
|
||||
@param protocol: This transport will deliver bytes to this protocol.
|
||||
@type protocol: L{IProtocol} provider
|
||||
|
||||
@param isServer: C{True} if this is the accepting side of the
|
||||
connection, C{False} if it is the connecting side.
|
||||
@type isServer: L{bool}
|
||||
|
||||
@param hostAddress: The value to return from C{getHost}. C{None}
|
||||
results in a new L{FakeAddress} being created to use as the value.
|
||||
@type hostAddress: L{IAddress} provider or L{NoneType}
|
||||
|
||||
@param peerAddress: The value to return from C{getPeer}. C{None}
|
||||
results in a new L{FakeAddress} being created to use as the value.
|
||||
@type peerAddress: L{IAddress} provider or L{NoneType}
|
||||
"""
|
||||
self.protocol = protocol
|
||||
self.isServer = isServer
|
||||
self.stream = []
|
||||
self.serial = self._nextserial()
|
||||
if hostAddress is None:
|
||||
hostAddress = FakeAddress()
|
||||
self.hostAddress = hostAddress
|
||||
if peerAddress is None:
|
||||
peerAddress = FakeAddress()
|
||||
self.peerAddress = peerAddress
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return 'FakeTransport<%s,%s,%s>' % (
|
||||
self.isServer and 'S' or 'C', self.serial,
|
||||
self.protocol.__class__.__name__)
|
||||
|
||||
|
||||
def write(self, data):
|
||||
if self.tls is not None:
|
||||
self.tlsbuf.append(data)
|
||||
else:
|
||||
self.stream.append(data)
|
||||
|
||||
def _checkProducer(self):
|
||||
# Cheating; this is called at "idle" times to allow producers to be
|
||||
# found and dealt with
|
||||
if self.producer:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
"""From abstract.FileDescriptor
|
||||
"""
|
||||
self.producer = producer
|
||||
self.streamingProducer = streaming
|
||||
if not streaming:
|
||||
producer.resumeProducing()
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.producer = None
|
||||
|
||||
def stopConsuming(self):
|
||||
self.unregisterProducer()
|
||||
self.loseConnection()
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write("".join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.disconnecting = True
|
||||
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
For the time being, this is the same as loseConnection; no buffered
|
||||
data will be lost.
|
||||
"""
|
||||
self.disconnecting = True
|
||||
|
||||
|
||||
def reportDisconnect(self):
|
||||
if self.tls is not None:
|
||||
# We were in the middle of negotiating! Must have been a TLS problem.
|
||||
err = NativeOpenSSLError()
|
||||
else:
|
||||
err = self.disconnectReason
|
||||
self.protocol.connectionLost(Failure(err))
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Identify this transport/event source to the logging system.
|
||||
"""
|
||||
return "iosim"
|
||||
|
||||
def getPeer(self):
|
||||
return self.peerAddress
|
||||
|
||||
def getHost(self):
|
||||
return self.hostAddress
|
||||
|
||||
def resumeProducing(self):
|
||||
# Never sends data anyways
|
||||
pass
|
||||
|
||||
def pauseProducing(self):
|
||||
# Never sends data anyways
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
self.loseConnection()
|
||||
|
||||
def startTLS(self, contextFactory, beNormal=True):
|
||||
# Nothing's using this feature yet, but startTLS has an undocumented
|
||||
# second argument which defaults to true; if set to False, servers will
|
||||
# behave like clients and clients will behave like servers.
|
||||
connectState = self.isServer ^ beNormal
|
||||
self.tls = TLSNegotiation(contextFactory, connectState)
|
||||
self.tlsbuf = []
|
||||
|
||||
|
||||
def getOutBuffer(self):
|
||||
"""
|
||||
Get the pending writes from this transport, clearing them from the
|
||||
pending buffer.
|
||||
|
||||
@return: the bytes written with C{transport.write}
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
S = self.stream
|
||||
if S:
|
||||
self.stream = []
|
||||
return b''.join(S)
|
||||
elif self.tls is not None:
|
||||
if self.tls.readyToSend:
|
||||
# Only _send_ the TLS negotiation "packet" if I'm ready to.
|
||||
self.tls.sent = True
|
||||
return self.tls
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def bufferReceived(self, buf):
|
||||
if isinstance(buf, TLSNegotiation):
|
||||
assert self.tls is not None # By the time you're receiving a
|
||||
# negotiation, you have to have called
|
||||
# startTLS already.
|
||||
if self.tls.sent:
|
||||
self.tls.pretendToVerify(buf, self)
|
||||
self.tls = None # we're done with the handshake if we've gotten
|
||||
# this far... although maybe it failed...?
|
||||
# TLS started! Unbuffer...
|
||||
b, self.tlsbuf = self.tlsbuf, None
|
||||
self.writeSequence(b)
|
||||
directlyProvides(self, interfaces.ISSLTransport)
|
||||
else:
|
||||
# We haven't sent our own TLS negotiation: time to do that!
|
||||
self.tls.readyToSend = True
|
||||
else:
|
||||
self.protocol.dataReceived(buf)
|
||||
|
||||
|
||||
|
||||
def makeFakeClient(clientProtocol):
|
||||
"""
|
||||
Create and return a new in-memory transport hooked up to the given protocol.
|
||||
|
||||
@param clientProtocol: The client protocol to use.
|
||||
@type clientProtocol: L{IProtocol} provider
|
||||
|
||||
@return: The transport.
|
||||
@rtype: L{FakeTransport}
|
||||
"""
|
||||
return FakeTransport(clientProtocol, isServer=False)
|
||||
|
||||
|
||||
|
||||
def makeFakeServer(serverProtocol):
|
||||
"""
|
||||
Create and return a new in-memory transport hooked up to the given protocol.
|
||||
|
||||
@param serverProtocol: The server protocol to use.
|
||||
@type serverProtocol: L{IProtocol} provider
|
||||
|
||||
@return: The transport.
|
||||
@rtype: L{FakeTransport}
|
||||
"""
|
||||
return FakeTransport(serverProtocol, isServer=True)
|
||||
|
||||
|
||||
|
||||
class IOPump:
|
||||
"""Utility to pump data between clients and servers for protocol testing.
|
||||
|
||||
Perhaps this is a utility worthy of being in protocol.py?
|
||||
"""
|
||||
def __init__(self, client, server, clientIO, serverIO, debug):
|
||||
self.client = client
|
||||
self.server = server
|
||||
self.clientIO = clientIO
|
||||
self.serverIO = serverIO
|
||||
self.debug = debug
|
||||
|
||||
def flush(self, debug=False):
|
||||
"""Pump until there is no more input or output.
|
||||
|
||||
Returns whether any data was moved.
|
||||
"""
|
||||
result = False
|
||||
for x in range(1000):
|
||||
if self.pump(debug):
|
||||
result = True
|
||||
else:
|
||||
break
|
||||
else:
|
||||
assert 0, "Too long"
|
||||
return result
|
||||
|
||||
|
||||
def pump(self, debug=False):
|
||||
"""Move data back and forth.
|
||||
|
||||
Returns whether any data was moved.
|
||||
"""
|
||||
if self.debug or debug:
|
||||
print('-- GLUG --')
|
||||
sData = self.serverIO.getOutBuffer()
|
||||
cData = self.clientIO.getOutBuffer()
|
||||
self.clientIO._checkProducer()
|
||||
self.serverIO._checkProducer()
|
||||
if self.debug or debug:
|
||||
print('.')
|
||||
# XXX slightly buggy in the face of incremental output
|
||||
if cData:
|
||||
print('C: ' + repr(cData))
|
||||
if sData:
|
||||
print('S: ' + repr(sData))
|
||||
if cData:
|
||||
self.serverIO.bufferReceived(cData)
|
||||
if sData:
|
||||
self.clientIO.bufferReceived(sData)
|
||||
if cData or sData:
|
||||
return True
|
||||
if (self.serverIO.disconnecting and
|
||||
not self.serverIO.disconnected):
|
||||
if self.debug or debug:
|
||||
print('* C')
|
||||
self.serverIO.disconnected = True
|
||||
self.clientIO.disconnecting = True
|
||||
self.clientIO.reportDisconnect()
|
||||
return True
|
||||
if self.clientIO.disconnecting and not self.clientIO.disconnected:
|
||||
if self.debug or debug:
|
||||
print('* S')
|
||||
self.clientIO.disconnected = True
|
||||
self.serverIO.disconnecting = True
|
||||
self.serverIO.reportDisconnect()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def connect(serverProtocol, serverTransport,
|
||||
clientProtocol, clientTransport, debug=False):
|
||||
"""
|
||||
Create a new L{IOPump} connecting two protocols.
|
||||
|
||||
@param serverProtocol: The protocol to use on the accepting side of the
|
||||
connection.
|
||||
@type serverProtocol: L{IProtocol} provider
|
||||
|
||||
@param serverTransport: The transport to associate with C{serverProtocol}.
|
||||
@type serverTransport: L{FakeTransport}
|
||||
|
||||
@param clientProtocol: The protocol to use on the initiating side of the
|
||||
connection.
|
||||
@type clientProtocol: L{IProtocol} provider
|
||||
|
||||
@param clientTransport: The transport to associate with C{clientProtocol}.
|
||||
@type clientTransport: L{FakeTransport}
|
||||
|
||||
@param debug: A flag indicating whether to log information about what the
|
||||
L{IOPump} is doing.
|
||||
@type debug: L{bool}
|
||||
|
||||
@return: An L{IOPump} which connects C{serverProtocol} and
|
||||
C{clientProtocol} and delivers bytes between them when it is pumped.
|
||||
@rtype: L{IOPump}
|
||||
"""
|
||||
serverProtocol.makeConnection(serverTransport)
|
||||
clientProtocol.makeConnection(clientTransport)
|
||||
pump = IOPump(
|
||||
clientProtocol, serverProtocol, clientTransport, serverTransport, debug
|
||||
)
|
||||
# kick off server greeting, etc
|
||||
pump.flush()
|
||||
return pump
|
||||
|
||||
|
||||
|
||||
def connectedServerAndClient(ServerClass, ClientClass,
|
||||
clientTransportFactory=makeFakeClient,
|
||||
serverTransportFactory=makeFakeServer,
|
||||
debug=False):
|
||||
"""
|
||||
Connect a given server and client class to each other.
|
||||
|
||||
@param ServerClass: a callable that produces the server-side protocol.
|
||||
@type ServerClass: 0-argument callable returning L{IProtocol} provider.
|
||||
|
||||
@param ClientClass: like C{ServerClass} but for the other side of the
|
||||
connection.
|
||||
@type ClientClass: 0-argument callable returning L{IProtocol} provider.
|
||||
|
||||
@param clientTransportFactory: a callable that produces the transport which
|
||||
will be attached to the protocol returned from C{ClientClass}.
|
||||
@type clientTransportFactory: callable taking (L{IProtocol}) and returning
|
||||
L{FakeTransport}
|
||||
|
||||
@param serverTransportFactory: a callable that produces the transport which
|
||||
will be attached to the protocol returned from C{ServerClass}.
|
||||
@type serverTransportFactory: callable taking (L{IProtocol}) and returning
|
||||
L{FakeTransport}
|
||||
|
||||
@param debug: Should this dump an escaped version of all traffic on this
|
||||
connection to stdout for inspection?
|
||||
@type debug: L{bool}
|
||||
|
||||
@return: the client protocol, the server protocol, and an L{IOPump} which,
|
||||
when its C{pump} and C{flush} methods are called, will move data
|
||||
between the created client and server protocol instances.
|
||||
@rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
|
||||
"""
|
||||
c = ClientClass()
|
||||
s = ServerClass()
|
||||
cio = clientTransportFactory(c)
|
||||
sio = serverTransportFactory(s)
|
||||
return c, s, connect(s, sio, c, cio, debug)
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This is a mock win32process module.
|
||||
|
||||
The purpose of this module is mock process creation for the PID test.
|
||||
|
||||
CreateProcess(...) will spawn a process, and always return a PID of 42.
|
||||
"""
|
||||
|
||||
import win32process
|
||||
GetExitCodeProcess = win32process.GetExitCodeProcess
|
||||
STARTUPINFO = win32process.STARTUPINFO
|
||||
|
||||
STARTF_USESTDHANDLES = win32process.STARTF_USESTDHANDLES
|
||||
|
||||
|
||||
def CreateProcess(appName,
|
||||
cmdline,
|
||||
procSecurity,
|
||||
threadSecurity,
|
||||
inheritHandles,
|
||||
newEnvironment,
|
||||
env,
|
||||
workingDir,
|
||||
startupInfo):
|
||||
"""
|
||||
This function mocks the generated pid aspect of the win32.CreateProcess
|
||||
function.
|
||||
- the true win32process.CreateProcess is called
|
||||
- return values are harvested in a tuple.
|
||||
- all return values from createProcess are passed back to the calling
|
||||
function except for the pid, the returned pid is hardcoded to 42
|
||||
"""
|
||||
|
||||
hProcess, hThread, dwPid, dwTid = win32process.CreateProcess(
|
||||
appName,
|
||||
cmdline,
|
||||
procSecurity,
|
||||
threadSecurity,
|
||||
inheritHandles,
|
||||
newEnvironment,
|
||||
env,
|
||||
workingDir,
|
||||
startupInfo)
|
||||
dwPid = 42
|
||||
return (hProcess, hThread, dwPid, dwTid)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
class A:
|
||||
def a(self):
|
||||
return 'a'
|
||||
try:
|
||||
object
|
||||
except NameError:
|
||||
pass
|
||||
else:
|
||||
class B(object, A):
|
||||
def b(self):
|
||||
return 'b'
|
||||
class Inherit(A):
|
||||
def a(self):
|
||||
return 'c'
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
|
||||
class A:
|
||||
def a(self):
|
||||
return 'b'
|
||||
try:
|
||||
object
|
||||
except NameError:
|
||||
pass
|
||||
else:
|
||||
class B(A, object):
|
||||
def b(self):
|
||||
return 'c'
|
||||
|
||||
class Inherit(A):
|
||||
def a(self):
|
||||
return 'd'
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# Don't change the docstring, it's part of the tests
|
||||
"""
|
||||
I'm a test drop-in. The plugin system's unit tests use me. No one
|
||||
else should.
|
||||
"""
|
||||
|
||||
from zope.interface import classProvides
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin, ITestPlugin2
|
||||
|
||||
|
||||
|
||||
class TestPlugin:
|
||||
"""
|
||||
A plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
classProvides(ITestPlugin,
|
||||
IPlugin)
|
||||
|
||||
def test1():
|
||||
pass
|
||||
test1 = staticmethod(test1)
|
||||
|
||||
|
||||
|
||||
class AnotherTestPlugin:
|
||||
"""
|
||||
Another plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
classProvides(ITestPlugin2,
|
||||
IPlugin)
|
||||
|
||||
def test():
|
||||
pass
|
||||
test = staticmethod(test)
|
||||
|
||||
|
||||
|
||||
class ThirdTestPlugin:
|
||||
"""
|
||||
Another plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
classProvides(ITestPlugin2,
|
||||
IPlugin)
|
||||
|
||||
def test():
|
||||
pass
|
||||
test = staticmethod(test)
|
||||
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test plugin used in L{twisted.test.test_plugin}.
|
||||
"""
|
||||
|
||||
from zope.interface import classProvides
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin
|
||||
|
||||
|
||||
|
||||
class FourthTestPlugin:
|
||||
classProvides(ITestPlugin,
|
||||
IPlugin)
|
||||
|
||||
def test1():
|
||||
pass
|
||||
test1 = staticmethod(test1)
|
||||
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test plugin used in L{twisted.test.test_plugin}.
|
||||
"""
|
||||
|
||||
from zope.interface import classProvides
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin
|
||||
|
||||
|
||||
|
||||
class FourthTestPlugin:
|
||||
classProvides(ITestPlugin,
|
||||
IPlugin)
|
||||
|
||||
def test1():
|
||||
pass
|
||||
test1 = staticmethod(test1)
|
||||
|
||||
|
||||
|
||||
class FifthTestPlugin:
|
||||
"""
|
||||
More documentation: I hate you.
|
||||
"""
|
||||
classProvides(ITestPlugin,
|
||||
IPlugin)
|
||||
|
||||
def test1():
|
||||
pass
|
||||
test1 = staticmethod(test1)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Write to stdout the command line args it received, one per line."""
|
||||
|
||||
import sys
|
||||
for x in sys.argv[1:]:
|
||||
print x
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""Write back all data it receives."""
|
||||
|
||||
import sys
|
||||
|
||||
data = sys.stdin.read(1)
|
||||
while data:
|
||||
sys.stdout.write(data)
|
||||
sys.stdout.flush()
|
||||
data = sys.stdin.read(1)
|
||||
sys.stderr.write("byebye")
|
||||
sys.stderr.flush()
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
|
||||
"""Write to a handful of file descriptors, to test the childFDs= argument of
|
||||
reactor.spawnProcess()
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
|
||||
debug = 0
|
||||
|
||||
if debug: stderr = os.fdopen(2, "w")
|
||||
|
||||
if debug: print >>stderr, "this is stderr"
|
||||
|
||||
abcd = os.read(0, 4)
|
||||
if debug: print >>stderr, "read(0):", abcd
|
||||
if abcd != "abcd":
|
||||
sys.exit(1)
|
||||
|
||||
if debug: print >>stderr, "os.write(1, righto)"
|
||||
|
||||
os.write(1, "righto")
|
||||
|
||||
efgh = os.read(3, 4)
|
||||
if debug: print >>stderr, "read(3):", efgh
|
||||
if efgh != "efgh":
|
||||
sys.exit(2)
|
||||
|
||||
if debug: print >>stderr, "os.close(4)"
|
||||
os.close(4)
|
||||
|
||||
eof = os.read(5, 4)
|
||||
if debug: print >>stderr, "read(5):", eof
|
||||
if eof != "":
|
||||
sys.exit(3)
|
||||
|
||||
if debug: print >>stderr, "os.write(1, closed)"
|
||||
os.write(1, "closed")
|
||||
|
||||
if debug: print >>stderr, "sys.exit(0)"
|
||||
sys.exit(0)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
|
||||
"""Write to a file descriptor and then close it, waiting a few seconds before
|
||||
quitting. This serves to make sure SIGCHLD is actually being noticed.
|
||||
"""
|
||||
|
||||
import os, sys, time
|
||||
|
||||
print "here is some text"
|
||||
time.sleep(1)
|
||||
print "goodbye"
|
||||
os.close(1)
|
||||
os.close(2)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Script used by test_process.TestTwoProcesses"""
|
||||
|
||||
# run until stdin is closed, then quit
|
||||
|
||||
import sys
|
||||
|
||||
while 1:
|
||||
d = sys.stdin.read()
|
||||
if len(d) == 0:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import sys, signal
|
||||
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
if getattr(signal, "SIGHUP", None) is not None:
|
||||
signal.signal(signal.SIGHUP, signal.SIG_DFL)
|
||||
print 'ok, signal us'
|
||||
sys.stdin.read()
|
||||
sys.exit(1)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""Script used by twisted.test.test_process on win32."""
|
||||
|
||||
import sys, time, os, msvcrt
|
||||
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
|
||||
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)
|
||||
|
||||
|
||||
sys.stdout.write("out\n")
|
||||
sys.stdout.flush()
|
||||
sys.stderr.write("err\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
data = sys.stdin.read()
|
||||
|
||||
sys.stdout.write(data)
|
||||
sys.stdout.write("\nout\n")
|
||||
sys.stderr.write("err\n")
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Test program for processes."""
|
||||
|
||||
import sys, os
|
||||
|
||||
test_file_match = "process_test.log.*"
|
||||
test_file = "process_test.log.%d" % os.getpid()
|
||||
|
||||
def main():
|
||||
f = open(test_file, 'wb')
|
||||
|
||||
# stage 1
|
||||
bytes = sys.stdin.read(4)
|
||||
f.write("one: %r\n" % bytes)
|
||||
# stage 2
|
||||
sys.stdout.write(bytes)
|
||||
sys.stdout.flush()
|
||||
os.close(sys.stdout.fileno())
|
||||
|
||||
# and a one, and a two, and a...
|
||||
bytes = sys.stdin.read(4)
|
||||
f.write("two: %r\n" % bytes)
|
||||
|
||||
# stage 3
|
||||
sys.stderr.write(bytes)
|
||||
sys.stderr.flush()
|
||||
os.close(sys.stderr.fileno())
|
||||
|
||||
# stage 4
|
||||
bytes = sys.stdin.read(4)
|
||||
f.write("three: %r\n" % bytes)
|
||||
|
||||
# exit with status code 23
|
||||
sys.exit(23)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
"""Test to make sure we can open /dev/tty"""
|
||||
|
||||
f = open("/dev/tty", "r+")
|
||||
a = f.readline()
|
||||
f.write(a)
|
||||
f.close()
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""A process that reads from stdin and out using Twisted."""
|
||||
|
||||
### Twisted Preamble
|
||||
# This makes sure that users don't have to set up their environment
|
||||
# specially in order to run these programs from bin/.
|
||||
import sys, os
|
||||
pos = os.path.abspath(sys.argv[0]).find(os.sep+'Twisted')
|
||||
if pos != -1:
|
||||
sys.path.insert(0, os.path.abspath(sys.argv[0])[:pos+8])
|
||||
sys.path.insert(0, os.curdir)
|
||||
### end of preamble
|
||||
|
||||
|
||||
from twisted.python import log
|
||||
from zope.interface import implements
|
||||
from twisted.internet import interfaces
|
||||
|
||||
log.startLogging(sys.stderr)
|
||||
|
||||
from twisted.internet import protocol, reactor, stdio
|
||||
|
||||
|
||||
class Echo(protocol.Protocol):
|
||||
implements(interfaces.IHalfCloseableProtocol)
|
||||
|
||||
def connectionMade(self):
|
||||
print "connection made"
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
def readConnectionLost(self):
|
||||
print "readConnectionLost"
|
||||
self.transport.loseConnection()
|
||||
def writeConnectionLost(self):
|
||||
print "writeConnectionLost"
|
||||
|
||||
def connectionLost(self, reason):
|
||||
print "connectionLost", reason
|
||||
reactor.stop()
|
||||
|
||||
stdio.StandardIO(Echo())
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,690 @@
|
|||
# -*- test-case-name: twisted.test.test_stringtransport -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Assorted functionality which is commonly useful when writing unit tests.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from socket import AF_INET, AF_INET6
|
||||
from io import BytesIO
|
||||
|
||||
from zope.interface import implementer, implementedBy
|
||||
from zope.interface.verify import verifyClass
|
||||
|
||||
from twisted.python import failure
|
||||
from twisted.python.compat import unicode
|
||||
from twisted.internet.interfaces import (
|
||||
ITransport, IConsumer, IPushProducer, IConnector, IReactorTCP, IReactorSSL,
|
||||
IReactorUNIX, IReactorSocket, IListeningPort, IReactorFDSet
|
||||
)
|
||||
from twisted.internet.abstract import isIPv6Address
|
||||
from twisted.internet.error import UnsupportedAddressFamily
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import protocol, error, address
|
||||
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.address import IPv4Address, UNIXAddress, IPv6Address
|
||||
|
||||
|
||||
class AccumulatingProtocol(protocol.Protocol):
|
||||
"""
|
||||
L{AccumulatingProtocol} is an L{IProtocol} implementation which collects
|
||||
the data delivered to it and can fire a Deferred when it is connected or
|
||||
disconnected.
|
||||
|
||||
@ivar made: A flag indicating whether C{connectionMade} has been called.
|
||||
@ivar data: Bytes giving all the data passed to C{dataReceived}.
|
||||
@ivar closed: A flag indicated whether C{connectionLost} has been called.
|
||||
@ivar closedReason: The value of the I{reason} parameter passed to
|
||||
C{connectionLost}.
|
||||
@ivar closedDeferred: If set to a L{Deferred}, this will be fired when
|
||||
C{connectionLost} is called.
|
||||
"""
|
||||
made = closed = 0
|
||||
closedReason = None
|
||||
|
||||
closedDeferred = None
|
||||
|
||||
data = b""
|
||||
|
||||
factory = None
|
||||
|
||||
def connectionMade(self):
|
||||
self.made = 1
|
||||
if (self.factory is not None and
|
||||
self.factory.protocolConnectionMade is not None):
|
||||
d = self.factory.protocolConnectionMade
|
||||
self.factory.protocolConnectionMade = None
|
||||
d.callback(self)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.data += data
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.closed = 1
|
||||
self.closedReason = reason
|
||||
if self.closedDeferred is not None:
|
||||
d, self.closedDeferred = self.closedDeferred, None
|
||||
d.callback(None)
|
||||
|
||||
|
||||
class LineSendingProtocol(basic.LineReceiver):
|
||||
lostConn = False
|
||||
|
||||
def __init__(self, lines, start = True):
|
||||
self.lines = lines[:]
|
||||
self.response = []
|
||||
self.start = start
|
||||
|
||||
def connectionMade(self):
|
||||
if self.start:
|
||||
for line in self.lines:
|
||||
self.sendLine(line)
|
||||
|
||||
def lineReceived(self, line):
|
||||
if not self.start:
|
||||
for line in self.lines:
|
||||
self.sendLine(line)
|
||||
self.lines = []
|
||||
self.response.append(line)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.lostConn = True
|
||||
|
||||
|
||||
class FakeDatagramTransport:
|
||||
noAddr = object()
|
||||
|
||||
def __init__(self):
|
||||
self.written = []
|
||||
|
||||
def write(self, packet, addr=noAddr):
|
||||
self.written.append((packet, addr))
|
||||
|
||||
|
||||
|
||||
@implementer(ITransport, IConsumer, IPushProducer)
|
||||
class StringTransport:
|
||||
"""
|
||||
A transport implementation which buffers data in memory and keeps track of
|
||||
its other state without providing any behavior.
|
||||
|
||||
L{StringTransport} has a number of attributes which are not part of any of
|
||||
the interfaces it claims to implement. These attributes are provided for
|
||||
testing purposes. Implementation code should not use any of these
|
||||
attributes; they are not provided by other transports.
|
||||
|
||||
@ivar disconnecting: A C{bool} which is C{False} until L{loseConnection} is
|
||||
called, then C{True}.
|
||||
|
||||
@ivar producer: If a producer is currently registered, C{producer} is a
|
||||
reference to it. Otherwise, C{None}.
|
||||
|
||||
@ivar streaming: If a producer is currently registered, C{streaming} refers
|
||||
to the value of the second parameter passed to C{registerProducer}.
|
||||
|
||||
@ivar hostAddr: C{None} or an object which will be returned as the host
|
||||
address of this transport. If C{None}, a nasty tuple will be returned
|
||||
instead.
|
||||
|
||||
@ivar peerAddr: C{None} or an object which will be returned as the peer
|
||||
address of this transport. If C{None}, a nasty tuple will be returned
|
||||
instead.
|
||||
|
||||
@ivar producerState: The state of this L{StringTransport} in its capacity
|
||||
as an L{IPushProducer}. One of C{'producing'}, C{'paused'}, or
|
||||
C{'stopped'}.
|
||||
|
||||
@ivar io: A L{BytesIO} which holds the data which has been written to this
|
||||
transport since the last call to L{clear}. Use L{value} instead of
|
||||
accessing this directly.
|
||||
"""
|
||||
|
||||
disconnecting = False
|
||||
|
||||
producer = None
|
||||
streaming = None
|
||||
|
||||
hostAddr = None
|
||||
peerAddr = None
|
||||
|
||||
producerState = 'producing'
|
||||
|
||||
def __init__(self, hostAddress=None, peerAddress=None):
|
||||
self.clear()
|
||||
if hostAddress is not None:
|
||||
self.hostAddr = hostAddress
|
||||
if peerAddress is not None:
|
||||
self.peerAddr = peerAddress
|
||||
self.connected = True
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Discard all data written to this transport so far.
|
||||
|
||||
This is not a transport method. It is intended for tests. Do not use
|
||||
it in implementation code.
|
||||
"""
|
||||
self.io = BytesIO()
|
||||
|
||||
|
||||
def value(self):
|
||||
"""
|
||||
Retrieve all data which has been buffered by this transport.
|
||||
|
||||
This is not a transport method. It is intended for tests. Do not use
|
||||
it in implementation code.
|
||||
|
||||
@return: A C{bytes} giving all data written to this transport since the
|
||||
last call to L{clear}.
|
||||
@rtype: C{bytes}
|
||||
"""
|
||||
return self.io.getvalue()
|
||||
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
if isinstance(data, unicode): # no, really, I mean it
|
||||
raise TypeError("Data must not be unicode")
|
||||
self.io.write(data)
|
||||
|
||||
|
||||
def writeSequence(self, data):
|
||||
self.io.write(b''.join(data))
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Close the connection. Does nothing besides toggle the C{disconnecting}
|
||||
instance variable to C{True}.
|
||||
"""
|
||||
self.disconnecting = True
|
||||
|
||||
|
||||
def getPeer(self):
|
||||
if self.peerAddr is None:
|
||||
return address.IPv4Address('TCP', '192.168.1.1', 54321)
|
||||
return self.peerAddr
|
||||
|
||||
|
||||
def getHost(self):
|
||||
if self.hostAddr is None:
|
||||
return address.IPv4Address('TCP', '10.0.0.1', 12345)
|
||||
return self.hostAddr
|
||||
|
||||
|
||||
# IConsumer
|
||||
def registerProducer(self, producer, streaming):
|
||||
if self.producer is not None:
|
||||
raise RuntimeError("Cannot register two producers")
|
||||
self.producer = producer
|
||||
self.streaming = streaming
|
||||
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self.producer is None:
|
||||
raise RuntimeError(
|
||||
"Cannot unregister a producer unless one is registered")
|
||||
self.producer = None
|
||||
self.streaming = None
|
||||
|
||||
|
||||
# IPushProducer
|
||||
def _checkState(self):
|
||||
if self.disconnecting:
|
||||
raise RuntimeError(
|
||||
"Cannot resume producing after loseConnection")
|
||||
if self.producerState == 'stopped':
|
||||
raise RuntimeError("Cannot resume a stopped producer")
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
self._checkState()
|
||||
self.producerState = 'paused'
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
self.producerState = 'stopped'
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
self._checkState()
|
||||
self.producerState = 'producing'
|
||||
|
||||
|
||||
|
||||
class StringTransportWithDisconnection(StringTransport):
|
||||
"""
|
||||
A L{StringTransport} which can be disconnected.
|
||||
"""
|
||||
|
||||
def loseConnection(self):
|
||||
if self.connected:
|
||||
self.connected = False
|
||||
self.protocol.connectionLost(
|
||||
failure.Failure(error.ConnectionDone("Bye.")))
|
||||
|
||||
|
||||
|
||||
class StringIOWithoutClosing(BytesIO):
|
||||
"""
|
||||
A BytesIO that can't be closed.
|
||||
"""
|
||||
def close(self):
|
||||
"""
|
||||
Do nothing.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(IListeningPort)
|
||||
class _FakePort(object):
|
||||
"""
|
||||
A fake L{IListeningPort} to be used in tests.
|
||||
|
||||
@ivar _hostAddress: The L{IAddress} this L{IListeningPort} is pretending
|
||||
to be listening on.
|
||||
"""
|
||||
|
||||
def __init__(self, hostAddress):
|
||||
"""
|
||||
@param hostAddress: An L{IAddress} this L{IListeningPort} should
|
||||
pretend to be listening on.
|
||||
"""
|
||||
self._hostAddress = hostAddress
|
||||
|
||||
|
||||
def startListening(self):
|
||||
"""
|
||||
Fake L{IListeningPort.startListening} that doesn't do anything.
|
||||
"""
|
||||
|
||||
|
||||
def stopListening(self):
|
||||
"""
|
||||
Fake L{IListeningPort.stopListening} that doesn't do anything.
|
||||
"""
|
||||
|
||||
|
||||
def getHost(self):
|
||||
"""
|
||||
Fake L{IListeningPort.getHost} that returns our L{IAddress}.
|
||||
"""
|
||||
return self._hostAddress
|
||||
|
||||
|
||||
|
||||
@implementer(IConnector)
|
||||
class _FakeConnector(object):
|
||||
"""
|
||||
A fake L{IConnector} that allows us to inspect if it has been told to stop
|
||||
connecting.
|
||||
|
||||
@ivar stoppedConnecting: has this connector's
|
||||
L{FakeConnector.stopConnecting} method been invoked yet?
|
||||
|
||||
@ivar _address: An L{IAddress} provider that represents our destination.
|
||||
"""
|
||||
_disconnected = False
|
||||
stoppedConnecting = False
|
||||
|
||||
def __init__(self, address):
|
||||
"""
|
||||
@param address: An L{IAddress} provider that represents this
|
||||
connector's destination.
|
||||
"""
|
||||
self._address = address
|
||||
|
||||
|
||||
def stopConnecting(self):
|
||||
"""
|
||||
Implement L{IConnector.stopConnecting} and set
|
||||
L{FakeConnector.stoppedConnecting} to C{True}
|
||||
"""
|
||||
self.stoppedConnecting = True
|
||||
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Implement L{IConnector.disconnect} as a no-op.
|
||||
"""
|
||||
self._disconnected = True
|
||||
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Implement L{IConnector.connect} as a no-op.
|
||||
"""
|
||||
|
||||
|
||||
def getDestination(self):
|
||||
"""
|
||||
Implement L{IConnector.getDestination} to return the C{address} passed
|
||||
to C{__init__}.
|
||||
"""
|
||||
return self._address
|
||||
|
||||
|
||||
|
||||
@implementer(
|
||||
IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket, IReactorFDSet
|
||||
)
|
||||
class MemoryReactor(object):
|
||||
"""
|
||||
A fake reactor to be used in tests. This reactor doesn't actually do
|
||||
much that's useful yet. It accepts TCP connection setup attempts, but
|
||||
they will never succeed.
|
||||
|
||||
@ivar tcpClients: a list that keeps track of connection attempts (ie, calls
|
||||
to C{connectTCP}).
|
||||
@type tcpClients: C{list}
|
||||
|
||||
@ivar tcpServers: a list that keeps track of server listen attempts (ie, calls
|
||||
to C{listenTCP}).
|
||||
@type tcpServers: C{list}
|
||||
|
||||
@ivar sslClients: a list that keeps track of connection attempts (ie,
|
||||
calls to C{connectSSL}).
|
||||
@type sslClients: C{list}
|
||||
|
||||
@ivar sslServers: a list that keeps track of server listen attempts (ie,
|
||||
calls to C{listenSSL}).
|
||||
@type sslServers: C{list}
|
||||
|
||||
@ivar unixClients: a list that keeps track of connection attempts (ie,
|
||||
calls to C{connectUNIX}).
|
||||
@type unixClients: C{list}
|
||||
|
||||
@ivar unixServers: a list that keeps track of server listen attempts (ie,
|
||||
calls to C{listenUNIX}).
|
||||
@type unixServers: C{list}
|
||||
|
||||
@ivar adoptedPorts: a list that keeps track of server listen attempts (ie,
|
||||
calls to C{adoptStreamPort}).
|
||||
|
||||
@ivar adoptedStreamConnections: a list that keeps track of stream-oriented
|
||||
connections added using C{adoptStreamConnection}.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the tracking lists.
|
||||
"""
|
||||
self.tcpClients = []
|
||||
self.tcpServers = []
|
||||
self.sslClients = []
|
||||
self.sslServers = []
|
||||
self.unixClients = []
|
||||
self.unixServers = []
|
||||
self.adoptedPorts = []
|
||||
self.adoptedStreamConnections = []
|
||||
self.connectors = []
|
||||
|
||||
self.readers = set()
|
||||
self.writers = set()
|
||||
|
||||
|
||||
def adoptStreamPort(self, fileno, addressFamily, factory):
|
||||
"""
|
||||
Fake L{IReactorSocket.adoptStreamPort}, that logs the call and returns
|
||||
an L{IListeningPort}.
|
||||
"""
|
||||
if addressFamily == AF_INET:
|
||||
addr = IPv4Address('TCP', '0.0.0.0', 1234)
|
||||
elif addressFamily == AF_INET6:
|
||||
addr = IPv6Address('TCP', '::', 1234)
|
||||
else:
|
||||
raise UnsupportedAddressFamily()
|
||||
|
||||
self.adoptedPorts.append((fileno, addressFamily, factory))
|
||||
return _FakePort(addr)
|
||||
|
||||
|
||||
def adoptStreamConnection(self, fileDescriptor, addressFamily, factory):
|
||||
"""
|
||||
Record the given stream connection in C{adoptedStreamConnections}.
|
||||
|
||||
@see: L{twisted.internet.interfaces.IReactorSocket.adoptStreamConnection}
|
||||
"""
|
||||
self.adoptedStreamConnections.append((
|
||||
fileDescriptor, addressFamily, factory))
|
||||
|
||||
|
||||
def adoptDatagramPort(self, fileno, addressFamily, protocol,
|
||||
maxPacketSize=8192):
|
||||
"""
|
||||
Fake L{IReactorSocket.adoptDatagramPort}, that logs the call and returns
|
||||
a fake L{IListeningPort}.
|
||||
|
||||
@see: L{twisted.internet.interfaces.IReactorSocket.adoptDatagramPort}
|
||||
"""
|
||||
if addressFamily == AF_INET:
|
||||
addr = IPv4Address('UDP', '0.0.0.0', 1234)
|
||||
elif addressFamily == AF_INET6:
|
||||
addr = IPv6Address('UDP', '::', 1234)
|
||||
else:
|
||||
raise UnsupportedAddressFamily()
|
||||
|
||||
self.adoptedPorts.append(
|
||||
(fileno, addressFamily, protocol, maxPacketSize))
|
||||
return _FakePort(addr)
|
||||
|
||||
|
||||
def listenTCP(self, port, factory, backlog=50, interface=''):
|
||||
"""
|
||||
Fake L{reactor.listenTCP}, that logs the call and returns an
|
||||
L{IListeningPort}.
|
||||
"""
|
||||
self.tcpServers.append((port, factory, backlog, interface))
|
||||
if isIPv6Address(interface):
|
||||
address = IPv6Address('TCP', interface, port)
|
||||
else:
|
||||
address = IPv4Address('TCP', '0.0.0.0', port)
|
||||
return _FakePort(address)
|
||||
|
||||
|
||||
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
|
||||
"""
|
||||
Fake L{reactor.connectTCP}, that logs the call and returns an
|
||||
L{IConnector}.
|
||||
"""
|
||||
self.tcpClients.append((host, port, factory, timeout, bindAddress))
|
||||
if isIPv6Address(host):
|
||||
conn = _FakeConnector(IPv6Address('TCP', host, port))
|
||||
else:
|
||||
conn = _FakeConnector(IPv4Address('TCP', host, port))
|
||||
factory.startedConnecting(conn)
|
||||
self.connectors.append(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def listenSSL(self, port, factory, contextFactory,
|
||||
backlog=50, interface=''):
|
||||
"""
|
||||
Fake L{reactor.listenSSL}, that logs the call and returns an
|
||||
L{IListeningPort}.
|
||||
"""
|
||||
self.sslServers.append((port, factory, contextFactory,
|
||||
backlog, interface))
|
||||
return _FakePort(IPv4Address('TCP', '0.0.0.0', port))
|
||||
|
||||
|
||||
def connectSSL(self, host, port, factory, contextFactory,
|
||||
timeout=30, bindAddress=None):
|
||||
"""
|
||||
Fake L{reactor.connectSSL}, that logs the call and returns an
|
||||
L{IConnector}.
|
||||
"""
|
||||
self.sslClients.append((host, port, factory, contextFactory,
|
||||
timeout, bindAddress))
|
||||
conn = _FakeConnector(IPv4Address('TCP', host, port))
|
||||
factory.startedConnecting(conn)
|
||||
self.connectors.append(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def listenUNIX(self, address, factory,
|
||||
backlog=50, mode=0o666, wantPID=0):
|
||||
"""
|
||||
Fake L{reactor.listenUNIX}, that logs the call and returns an
|
||||
L{IListeningPort}.
|
||||
"""
|
||||
self.unixServers.append((address, factory, backlog, mode, wantPID))
|
||||
return _FakePort(UNIXAddress(address))
|
||||
|
||||
|
||||
def connectUNIX(self, address, factory, timeout=30, checkPID=0):
|
||||
"""
|
||||
Fake L{reactor.connectUNIX}, that logs the call and returns an
|
||||
L{IConnector}.
|
||||
"""
|
||||
self.unixClients.append((address, factory, timeout, checkPID))
|
||||
conn = _FakeConnector(UNIXAddress(address))
|
||||
factory.startedConnecting(conn)
|
||||
self.connectors.append(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def addReader(self, reader):
|
||||
"""
|
||||
Fake L{IReactorFDSet.addReader} which adds the reader to a local set.
|
||||
"""
|
||||
self.readers.add(reader)
|
||||
|
||||
|
||||
def removeReader(self, reader):
|
||||
"""
|
||||
Fake L{IReactorFDSet.removeReader} which removes the reader from a
|
||||
local set.
|
||||
"""
|
||||
self.readers.discard(reader)
|
||||
|
||||
|
||||
def addWriter(self, writer):
|
||||
"""
|
||||
Fake L{IReactorFDSet.addWriter} which adds the writer to a local set.
|
||||
"""
|
||||
self.writers.add(writer)
|
||||
|
||||
|
||||
def removeWriter(self, writer):
|
||||
"""
|
||||
Fake L{IReactorFDSet.removeWriter} which removes the writer from a
|
||||
local set.
|
||||
"""
|
||||
self.writers.discard(writer)
|
||||
|
||||
|
||||
def getReaders(self):
|
||||
"""
|
||||
Fake L{IReactorFDSet.getReaders} which returns a list of readers from
|
||||
the local set.
|
||||
"""
|
||||
return list(self.readers)
|
||||
|
||||
|
||||
def getWriters(self):
|
||||
"""
|
||||
Fake L{IReactorFDSet.getWriters} which returns a list of writers from
|
||||
the local set.
|
||||
"""
|
||||
return list(self.writers)
|
||||
|
||||
|
||||
def removeAll(self):
|
||||
"""
|
||||
Fake L{IReactorFDSet.removeAll} which removed all readers and writers
|
||||
from the local sets.
|
||||
"""
|
||||
self.readers.clear()
|
||||
self.writers.clear()
|
||||
|
||||
|
||||
for iface in implementedBy(MemoryReactor):
|
||||
verifyClass(iface, MemoryReactor)
|
||||
|
||||
|
||||
|
||||
class MemoryReactorClock(MemoryReactor, Clock):
|
||||
def __init__(self):
|
||||
MemoryReactor.__init__(self)
|
||||
Clock.__init__(self)
|
||||
|
||||
|
||||
|
||||
@implementer(IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket)
|
||||
class RaisingMemoryReactor(object):
|
||||
"""
|
||||
A fake reactor to be used in tests. It accepts TCP connection setup
|
||||
attempts, but they will fail.
|
||||
|
||||
@ivar _listenException: An instance of an L{Exception}
|
||||
@ivar _connectException: An instance of an L{Exception}
|
||||
"""
|
||||
|
||||
def __init__(self, listenException=None, connectException=None):
|
||||
"""
|
||||
@param listenException: An instance of an L{Exception} to raise when any
|
||||
C{listen} method is called.
|
||||
|
||||
@param connectException: An instance of an L{Exception} to raise when
|
||||
any C{connect} method is called.
|
||||
"""
|
||||
self._listenException = listenException
|
||||
self._connectException = connectException
|
||||
|
||||
|
||||
def adoptStreamPort(self, fileno, addressFamily, factory):
|
||||
"""
|
||||
Fake L{IReactorSocket.adoptStreamPort}, that raises
|
||||
L{self._listenException}.
|
||||
"""
|
||||
raise self._listenException
|
||||
|
||||
|
||||
def listenTCP(self, port, factory, backlog=50, interface=''):
|
||||
"""
|
||||
Fake L{reactor.listenTCP}, that raises L{self._listenException}.
|
||||
"""
|
||||
raise self._listenException
|
||||
|
||||
|
||||
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
|
||||
"""
|
||||
Fake L{reactor.connectTCP}, that raises L{self._connectException}.
|
||||
"""
|
||||
raise self._connectException
|
||||
|
||||
|
||||
def listenSSL(self, port, factory, contextFactory,
|
||||
backlog=50, interface=''):
|
||||
"""
|
||||
Fake L{reactor.listenSSL}, that raises L{self._listenException}.
|
||||
"""
|
||||
raise self._listenException
|
||||
|
||||
|
||||
def connectSSL(self, host, port, factory, contextFactory,
|
||||
timeout=30, bindAddress=None):
|
||||
"""
|
||||
Fake L{reactor.connectSSL}, that raises L{self._connectException}.
|
||||
"""
|
||||
raise self._connectException
|
||||
|
||||
|
||||
def listenUNIX(self, address, factory,
|
||||
backlog=50, mode=0o666, wantPID=0):
|
||||
"""
|
||||
Fake L{reactor.listenUNIX}, that raises L{self._listenException}.
|
||||
"""
|
||||
raise self._listenException
|
||||
|
||||
|
||||
def connectUNIX(self, address, factory, timeout=30, checkPID=0):
|
||||
"""
|
||||
Fake L{reactor.connectUNIX}, that raises L{self._connectException}.
|
||||
"""
|
||||
raise self._connectException
|
||||
1443
Linux_i686/lib/python2.7/site-packages/twisted/test/raiser.c
Normal file
1443
Linux_i686/lib/python2.7/site-packages/twisted/test/raiser.c
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
A trivial extension that just raises an exception.
|
||||
See L{twisted.test.test_failure.test_failureConstructionWithMungedStackSucceeds}.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class RaiserException(Exception):
|
||||
"""
|
||||
A speficic exception only used to be identified in tests.
|
||||
"""
|
||||
|
||||
|
||||
def raiseException():
|
||||
"""
|
||||
Raise L{RaiserException}.
|
||||
"""
|
||||
raise RaiserException("This function is intentionally broken")
|
||||
BIN
Linux_i686/lib/python2.7/site-packages/twisted/test/raiser.so
Executable file
BIN
Linux_i686/lib/python2.7/site-packages/twisted/test/raiser.so
Executable file
Binary file not shown.
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
# Helper for a test_reflect test
|
||||
|
||||
import idonotexist
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
# Helper for a test_reflect test
|
||||
|
||||
raise ValueError("Stuff is broken and things")
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
# Helper module for a test_reflect test
|
||||
|
||||
1//0
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIDBjCCAm+gAwIBAgIBATANBgkqhkiG9w0BAQQFADB7MQswCQYDVQQGEwJTRzER
|
||||
MA8GA1UEChMITTJDcnlwdG8xFDASBgNVBAsTC00yQ3J5cHRvIENBMSQwIgYDVQQD
|
||||
ExtNMkNyeXB0byBDZXJ0aWZpY2F0ZSBNYXN0ZXIxHTAbBgkqhkiG9w0BCQEWDm5n
|
||||
cHNAcG9zdDEuY29tMB4XDTAwMDkxMDA5NTEzMFoXDTAyMDkxMDA5NTEzMFowUzEL
|
||||
MAkGA1UEBhMCU0cxETAPBgNVBAoTCE0yQ3J5cHRvMRIwEAYDVQQDEwlsb2NhbGhv
|
||||
c3QxHTAbBgkqhkiG9w0BCQEWDm5ncHNAcG9zdDEuY29tMFwwDQYJKoZIhvcNAQEB
|
||||
BQADSwAwSAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
|
||||
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAaOCAQQwggEAMAkGA1UdEwQC
|
||||
MAAwLAYJYIZIAYb4QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRl
|
||||
MB0GA1UdDgQWBBTPhIKSvnsmYsBVNWjj0m3M2z0qVTCBpQYDVR0jBIGdMIGagBT7
|
||||
hyNp65w6kxXlxb8pUU/+7Sg4AaF/pH0wezELMAkGA1UEBhMCU0cxETAPBgNVBAoT
|
||||
CE0yQ3J5cHRvMRQwEgYDVQQLEwtNMkNyeXB0byBDQTEkMCIGA1UEAxMbTTJDcnlw
|
||||
dG8gQ2VydGlmaWNhdGUgTWFzdGVyMR0wGwYJKoZIhvcNAQkBFg5uZ3BzQHBvc3Qx
|
||||
LmNvbYIBADANBgkqhkiG9w0BAQQFAAOBgQA7/CqT6PoHycTdhEStWNZde7M/2Yc6
|
||||
BoJuVwnW8YxGO8Sn6UJ4FeffZNcYZddSDKosw8LtPOeWoK3JINjAk5jiPQ2cww++
|
||||
7QGG/g5NDjxFZNDJP1dGiLAxPW6JXwov4v0FmdzfLOZ01jDcgQQZqEpYlgpuI5JE
|
||||
WUQ9Ho4EzbYCOQ==
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIBPAIBAAJBAKy+e3dulvXzV7zoTZWc5TzgApr8DmeQHTYC8ydfzH7EECe4R1Xh
|
||||
5kwIzOuuFfn178FBiS84gngaNcrFi0Z5fAkCAwEAAQJBAIqm/bz4NA1H++Vx5Ewx
|
||||
OcKp3w19QSaZAwlGRtsUxrP7436QjnREM3Bm8ygU11BjkPVmtrKm6AayQfCHqJoT
|
||||
ZIECIQDW0BoMoL0HOYM/mrTLhaykYAVqgIeJsPjvkEhTFXWBuQIhAM3deFAvWNu4
|
||||
nklUQ37XsCT2c9tmNt1LAT+slG2JOTTRAiAuXDtC/m3NYVwyHfFm+zKHRzHkClk2
|
||||
HjubeEgjpj32AQIhAJqMGTaZVOwevTXvvHwNEH+vRWsAYU/gbx+OQB+7VOcBAiEA
|
||||
oolb6NMg/R3enNPvS1O4UU1H8wpaF77L4yiSWlE0p4w=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIBDTCBuAIBADBTMQswCQYDVQQGEwJTRzERMA8GA1UEChMITTJDcnlwdG8xEjAQ
|
||||
BgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3DQEJARYObmdwc0Bwb3N0MS5jb20w
|
||||
XDANBgkqhkiG9w0BAQEFAANLADBIAkEArL57d26W9fNXvOhNlZzlPOACmvwOZ5Ad
|
||||
NgLzJ1/MfsQQJ7hHVeHmTAjM664V+fXvwUGJLziCeBo1ysWLRnl8CQIDAQABoAAw
|
||||
DQYJKoZIhvcNAQEEBQADQQA7uqbrNTjVWpF6By5ZNPvhZ4YdFgkeXFVWi5ao/TaP
|
||||
Vq4BG021fJ9nlHRtr4rotpgHDX1rr+iWeHKsx4+5DRSy
|
||||
-----END CERTIFICATE REQUEST-----
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Helper classes for twisted.test.test_ssl.
|
||||
|
||||
They are in a separate module so they will not prevent test_ssl importing if
|
||||
pyOpenSSL is unavailable.
|
||||
"""
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.python.compat import nativeString
|
||||
from twisted.internet import ssl
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
certPath = nativeString(FilePath(__file__.encode("utf-8")
|
||||
).sibling(b"server.pem").path)
|
||||
|
||||
|
||||
class ClientTLSContext(ssl.ClientContextFactory):
|
||||
isClient = 1
|
||||
def getContext(self):
|
||||
return SSL.Context(SSL.TLSv1_METHOD)
|
||||
|
||||
class ServerTLSContext:
|
||||
isClient = 0
|
||||
|
||||
def __init__(self, filename=certPath):
|
||||
self.filename = filename
|
||||
|
||||
def getContext(self):
|
||||
ctx = SSL.Context(SSL.TLSv1_METHOD)
|
||||
ctx.use_certificate_file(self.filename)
|
||||
ctx.use_privatekey_file(self.filename)
|
||||
return ctx
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_consumer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_consumer} to test
|
||||
that process transports implement IConsumer properly.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.python import log, reflect
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.protocols import basic
|
||||
|
||||
def failed(err):
|
||||
log.startLogging(sys.stderr)
|
||||
log.err(err)
|
||||
|
||||
class ConsumerChild(protocol.Protocol):
|
||||
def __init__(self, junkPath):
|
||||
self.junkPath = junkPath
|
||||
|
||||
def connectionMade(self):
|
||||
d = basic.FileSender().beginFileTransfer(file(self.junkPath), self.transport)
|
||||
d.addErrback(failed)
|
||||
d.addCallback(lambda ign: self.transport.loseConnection())
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
stdio.StandardIO(ConsumerChild(sys.argv[2]))
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_readConnectionLost -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_readConnectionLost}
|
||||
to test that IHalfCloseableProtocol.readConnectionLost works for process
|
||||
transports.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from zope.interface import implements
|
||||
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import reflect, log
|
||||
|
||||
|
||||
class HalfCloseProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol to hook up to stdio and observe its transport being
|
||||
half-closed. If all goes as expected, C{exitCode} will be set to C{0};
|
||||
otherwise it will be set to C{1} to indicate failure.
|
||||
"""
|
||||
implements(IHalfCloseableProtocol)
|
||||
|
||||
exitCode = None
|
||||
|
||||
def connectionMade(self):
|
||||
"""
|
||||
Signal the parent process that we're ready.
|
||||
"""
|
||||
self.transport.write("x")
|
||||
|
||||
|
||||
def readConnectionLost(self):
|
||||
"""
|
||||
This is the desired event. Once it has happened, stop the reactor so
|
||||
the process will exit.
|
||||
"""
|
||||
self.exitCode = 0
|
||||
reactor.stop()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
This may only be invoked after C{readConnectionLost}. If it happens
|
||||
otherwise, mark it as an error and shut down.
|
||||
"""
|
||||
if self.exitCode is None:
|
||||
self.exitCode = 1
|
||||
log.err(reason, "Unexpected call to connectionLost")
|
||||
reactor.stop()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(file(sys.argv[2], 'w'))
|
||||
from twisted.internet import reactor
|
||||
protocol = HalfCloseProtocol()
|
||||
stdio.StandardIO(protocol)
|
||||
reactor.run()
|
||||
sys.exit(protocol.exitCode)
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_hostAndPeer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_hostAndPeer} to test
|
||||
that ITransport.getHost() and ITransport.getPeer() work for process transports.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import reflect
|
||||
|
||||
class HostPeerChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
self.transport.write('\n'.join([
|
||||
str(self.transport.getHost()),
|
||||
str(self.transport.getPeer())]))
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
stdio.StandardIO(HostPeerChild())
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_lastWriteReceived -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_lastWriteReceived}
|
||||
to test that L{os.write} can be reliably used after
|
||||
L{twisted.internet.stdio.StandardIO} has finished.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.stdio import StandardIO
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
|
||||
class LastWriteChild(Protocol):
|
||||
def __init__(self, reactor, magicString):
|
||||
self.reactor = reactor
|
||||
self.magicString = magicString
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.magicString)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.reactor.stop()
|
||||
|
||||
|
||||
|
||||
def main(reactor, magicString):
|
||||
p = LastWriteChild(reactor, magicString)
|
||||
StandardIO(p)
|
||||
reactor.run()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
main(reactor, sys.argv[2])
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_loseConnection -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_loseConnection} to
|
||||
test that ITransport.loseConnection() works for process transports.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import reflect, log
|
||||
|
||||
class LoseConnChild(protocol.Protocol):
|
||||
exitCode = 0
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Check that C{reason} is a L{Failure} wrapping a L{ConnectionDone}
|
||||
instance and stop the reactor. If C{reason} is wrong for some reason,
|
||||
log something about that in C{self.errorLogFile} and make sure the
|
||||
process exits with a non-zero status.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
reason.trap(ConnectionDone)
|
||||
except:
|
||||
log.err(None, "Problem with reason passed to connectionLost")
|
||||
self.exitCode = 1
|
||||
finally:
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(file(sys.argv[2], 'w'))
|
||||
from twisted.internet import reactor
|
||||
protocol = LoseConnChild()
|
||||
stdio.StandardIO(protocol)
|
||||
reactor.run()
|
||||
sys.exit(protocol.exitCode)
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_producer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_producer} to test
|
||||
that process transports implement IProducer properly.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import log, reflect
|
||||
|
||||
class ProducerChild(protocol.Protocol):
|
||||
_paused = False
|
||||
buf = ''
|
||||
|
||||
def connectionLost(self, reason):
|
||||
log.msg("*****OVER*****")
|
||||
reactor.callLater(1, reactor.stop)
|
||||
# reactor.stop()
|
||||
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
self.buf += bytes
|
||||
if self._paused:
|
||||
log.startLogging(sys.stderr)
|
||||
log.msg("dataReceived while transport paused!")
|
||||
self.transport.loseConnection()
|
||||
else:
|
||||
self.transport.write(bytes)
|
||||
if self.buf.endswith('\n0\n'):
|
||||
self.transport.loseConnection()
|
||||
else:
|
||||
self.pause()
|
||||
|
||||
|
||||
def pause(self):
|
||||
self._paused = True
|
||||
self.transport.pauseProducing()
|
||||
reactor.callLater(0.01, self.unpause)
|
||||
|
||||
|
||||
def unpause(self):
|
||||
self._paused = False
|
||||
self.transport.resumeProducing()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
stdio.StandardIO(ProducerChild())
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_write -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_write} to test that
|
||||
ITransport.write() works for process transports.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import reflect
|
||||
|
||||
class WriteChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
for ch in 'ok!':
|
||||
self.transport.write(ch)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
stdio.StandardIO(WriteChild())
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTestCase.test_writeSequence -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTestCase.test_writeSequence} to test that
|
||||
ITransport.writeSequence() works for process transports.
|
||||
"""
|
||||
|
||||
import sys, _preamble
|
||||
|
||||
from twisted.internet import stdio, protocol
|
||||
from twisted.python import reflect
|
||||
|
||||
class WriteSequenceChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
self.transport.writeSequence(list('ok!'))
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
stdio.StandardIO(WriteSequenceChild())
|
||||
reactor.run()
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for generic file descriptor based reactor support code.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
|
||||
|
||||
class AddressTests(TestCase):
|
||||
"""
|
||||
Tests for address-related functionality.
|
||||
"""
|
||||
def test_decimalDotted(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{True} for any decimal dotted
|
||||
representation of an IPv4 address.
|
||||
"""
|
||||
self.assertTrue(isIPAddress('0.1.2.3'))
|
||||
self.assertTrue(isIPAddress('252.253.254.255'))
|
||||
|
||||
|
||||
def test_shortDecimalDotted(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a dotted decimal
|
||||
representation with fewer or more than four octets.
|
||||
"""
|
||||
self.assertFalse(isIPAddress('0'))
|
||||
self.assertFalse(isIPAddress('0.1'))
|
||||
self.assertFalse(isIPAddress('0.1.2'))
|
||||
self.assertFalse(isIPAddress('0.1.2.3.4'))
|
||||
|
||||
|
||||
def test_invalidLetters(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for any non-decimal dotted
|
||||
representation including letters.
|
||||
"""
|
||||
self.assertFalse(isIPAddress('a.2.3.4'))
|
||||
self.assertFalse(isIPAddress('1.b.3.4'))
|
||||
|
||||
|
||||
def test_invalidPunctuation(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a string containing
|
||||
strange punctuation.
|
||||
"""
|
||||
self.assertFalse(isIPAddress(','))
|
||||
self.assertFalse(isIPAddress('1,2'))
|
||||
self.assertFalse(isIPAddress('1,2,3'))
|
||||
self.assertFalse(isIPAddress('1.,.3,4'))
|
||||
|
||||
|
||||
def test_emptyString(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for the empty string.
|
||||
"""
|
||||
self.assertFalse(isIPAddress(''))
|
||||
|
||||
|
||||
def test_invalidNegative(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for negative decimal values.
|
||||
"""
|
||||
self.assertFalse(isIPAddress('-1'))
|
||||
self.assertFalse(isIPAddress('1.-2'))
|
||||
self.assertFalse(isIPAddress('1.2.-3'))
|
||||
self.assertFalse(isIPAddress('1.2.-3.4'))
|
||||
|
||||
|
||||
def test_invalidPositive(self):
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a string containing
|
||||
positive decimal values greater than 255.
|
||||
"""
|
||||
self.assertFalse(isIPAddress('256.0.0.0'))
|
||||
self.assertFalse(isIPAddress('0.256.0.0'))
|
||||
self.assertFalse(isIPAddress('0.0.256.0'))
|
||||
self.assertFalse(isIPAddress('0.0.0.256'))
|
||||
self.assertFalse(isIPAddress('256.256.256.256'))
|
||||
|
|
@ -0,0 +1,819 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Tests for twisted.enterprise.adbapi.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
import os, stat
|
||||
import types
|
||||
|
||||
from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost
|
||||
from twisted.enterprise.adbapi import Connection, Transaction
|
||||
from twisted.internet import reactor, defer, interfaces
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
||||
simple_table_schema = """
|
||||
CREATE TABLE simple (
|
||||
x integer
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class ADBAPITestBase:
|
||||
"""Test the asynchronous DB-API code."""
|
||||
|
||||
openfun_called = {}
|
||||
|
||||
if interfaces.IReactorThreads(reactor, None) is None:
|
||||
skip = "ADB-API requires threads, no way to test without them"
|
||||
|
||||
def extraSetUp(self):
|
||||
"""
|
||||
Set up the database and create a connection pool pointing at it.
|
||||
"""
|
||||
self.startDB()
|
||||
self.dbpool = self.makePool(cp_openfun=self.openfun)
|
||||
self.dbpool.start()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
d = self.dbpool.runOperation('DROP TABLE simple')
|
||||
d.addCallback(lambda res: self.dbpool.close())
|
||||
d.addCallback(lambda res: self.stopDB())
|
||||
return d
|
||||
|
||||
def openfun(self, conn):
|
||||
self.openfun_called[conn] = True
|
||||
|
||||
def checkOpenfunCalled(self, conn=None):
|
||||
if not conn:
|
||||
self.failUnless(self.openfun_called)
|
||||
else:
|
||||
self.failUnless(self.openfun_called.has_key(conn))
|
||||
|
||||
def testPool(self):
|
||||
d = self.dbpool.runOperation(simple_table_schema)
|
||||
if self.test_failures:
|
||||
d.addCallback(self._testPool_1_1)
|
||||
d.addCallback(self._testPool_1_2)
|
||||
d.addCallback(self._testPool_1_3)
|
||||
d.addCallback(self._testPool_1_4)
|
||||
d.addCallback(lambda res: self.flushLoggedErrors())
|
||||
d.addCallback(self._testPool_2)
|
||||
d.addCallback(self._testPool_3)
|
||||
d.addCallback(self._testPool_4)
|
||||
d.addCallback(self._testPool_5)
|
||||
d.addCallback(self._testPool_6)
|
||||
d.addCallback(self._testPool_7)
|
||||
d.addCallback(self._testPool_8)
|
||||
d.addCallback(self._testPool_9)
|
||||
return d
|
||||
|
||||
def _testPool_1_1(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_2(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runOperation,
|
||||
"deletexxx from NOTABLE")
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_3(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runInteraction,
|
||||
self.bad_interaction)
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_4(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runWithConnection,
|
||||
self.bad_withConnection)
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_2(self, res):
|
||||
# verify simple table is empty
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
def _check(row):
|
||||
self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
|
||||
self.checkOpenfunCalled()
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_3(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
inserts = []
|
||||
# add some rows to simple table (runOperation)
|
||||
for i in range(self.num_iterations):
|
||||
sql = "insert into simple(x) values(%d)" % i
|
||||
inserts.append(self.dbpool.runOperation(sql))
|
||||
d = defer.gatherResults(inserts)
|
||||
|
||||
def _select(res):
|
||||
# make sure they were added (runQuery)
|
||||
sql = "select x from simple order by x";
|
||||
d = self.dbpool.runQuery(sql)
|
||||
return d
|
||||
d.addCallback(_select)
|
||||
|
||||
def _check(rows):
|
||||
self.failUnless(len(rows) == self.num_iterations,
|
||||
"Wrong number of rows")
|
||||
for i in range(self.num_iterations):
|
||||
self.failUnless(len(rows[i]) == 1, "Wrong size row")
|
||||
self.failUnless(rows[i][0] == i, "Values not returned.")
|
||||
d.addCallback(_check)
|
||||
|
||||
return d
|
||||
|
||||
def _testPool_4(self, res):
|
||||
# runInteraction
|
||||
d = self.dbpool.runInteraction(self.interaction)
|
||||
d.addCallback(lambda res: self.assertEqual(res, "done"))
|
||||
return d
|
||||
|
||||
def _testPool_5(self, res):
|
||||
# withConnection
|
||||
d = self.dbpool.runWithConnection(self.withConnection)
|
||||
d.addCallback(lambda res: self.assertEqual(res, "done"))
|
||||
return d
|
||||
|
||||
def _testPool_6(self, res):
|
||||
# Test a withConnection cannot be closed
|
||||
d = self.dbpool.runWithConnection(self.close_withConnection)
|
||||
return d
|
||||
|
||||
def _testPool_7(self, res):
|
||||
# give the pool a workout
|
||||
ds = []
|
||||
for i in range(self.num_iterations):
|
||||
sql = "select x from simple where x = %d" % i
|
||||
ds.append(self.dbpool.runQuery(sql))
|
||||
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
|
||||
def _check(result):
|
||||
for i in range(self.num_iterations):
|
||||
self.failUnless(result[i][1][0][0] == i, "Value not returned")
|
||||
dlist.addCallback(_check)
|
||||
return dlist
|
||||
|
||||
def _testPool_8(self, res):
|
||||
# now delete everything
|
||||
ds = []
|
||||
for i in range(self.num_iterations):
|
||||
sql = "delete from simple where x = %d" % i
|
||||
ds.append(self.dbpool.runOperation(sql))
|
||||
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
|
||||
return dlist
|
||||
|
||||
def _testPool_9(self, res):
|
||||
# verify simple table is empty
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
def _check(row):
|
||||
self.failUnless(int(row[0][0]) == 0,
|
||||
"Didn't successfully delete table contents")
|
||||
self.checkConnect()
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def checkConnect(self):
|
||||
"""Check the connect/disconnect synchronous calls."""
|
||||
conn = self.dbpool.connect()
|
||||
self.checkOpenfunCalled(conn)
|
||||
curs = conn.cursor()
|
||||
curs.execute("insert into simple(x) values(1)")
|
||||
curs.execute("select x from simple")
|
||||
res = curs.fetchall()
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(len(res[0]), 1)
|
||||
self.assertEqual(res[0][0], 1)
|
||||
curs.execute("delete from simple")
|
||||
curs.execute("select x from simple")
|
||||
self.assertEqual(len(curs.fetchall()), 0)
|
||||
curs.close()
|
||||
self.dbpool.disconnect(conn)
|
||||
|
||||
def interaction(self, transaction):
|
||||
transaction.execute("select x from simple order by x")
|
||||
for i in range(self.num_iterations):
|
||||
row = transaction.fetchone()
|
||||
self.failUnless(len(row) == 1, "Wrong size row")
|
||||
self.failUnless(row[0] == i, "Value not returned.")
|
||||
# should test this, but gadfly throws an exception instead
|
||||
#self.failUnless(transaction.fetchone() is None, "Too many rows")
|
||||
return "done"
|
||||
|
||||
def bad_interaction(self, transaction):
|
||||
if self.can_rollback:
|
||||
transaction.execute("insert into simple(x) values(0)")
|
||||
|
||||
transaction.execute("select * from NOTABLE")
|
||||
|
||||
def withConnection(self, conn):
|
||||
curs = conn.cursor()
|
||||
try:
|
||||
curs.execute("select x from simple order by x")
|
||||
for i in range(self.num_iterations):
|
||||
row = curs.fetchone()
|
||||
self.failUnless(len(row) == 1, "Wrong size row")
|
||||
self.failUnless(row[0] == i, "Value not returned.")
|
||||
# should test this, but gadfly throws an exception instead
|
||||
#self.failUnless(transaction.fetchone() is None, "Too many rows")
|
||||
finally:
|
||||
curs.close()
|
||||
return "done"
|
||||
|
||||
def close_withConnection(self, conn):
|
||||
conn.close()
|
||||
|
||||
def bad_withConnection(self, conn):
|
||||
curs = conn.cursor()
|
||||
try:
|
||||
curs.execute("select * from NOTABLE")
|
||||
finally:
|
||||
curs.close()
|
||||
|
||||
|
||||
class ReconnectTestBase:
|
||||
"""Test the asynchronous DB-API code with reconnect."""
|
||||
|
||||
if interfaces.IReactorThreads(reactor, None) is None:
|
||||
skip = "ADB-API requires threads, no way to test without them"
|
||||
|
||||
def extraSetUp(self):
|
||||
"""
|
||||
Skip the test if C{good_sql} is unavailable. Otherwise, set up the
|
||||
database, create a connection pool pointed at it, and set up a simple
|
||||
schema in it.
|
||||
"""
|
||||
if self.good_sql is None:
|
||||
raise unittest.SkipTest('no good sql for reconnect test')
|
||||
self.startDB()
|
||||
self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
|
||||
cp_good_sql=self.good_sql)
|
||||
self.dbpool.start()
|
||||
return self.dbpool.runOperation(simple_table_schema)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
d = self.dbpool.runOperation('DROP TABLE simple')
|
||||
d.addCallback(lambda res: self.dbpool.close())
|
||||
d.addCallback(lambda res: self.stopDB())
|
||||
return d
|
||||
|
||||
def testPool(self):
|
||||
d = defer.succeed(None)
|
||||
d.addCallback(self._testPool_1)
|
||||
d.addCallback(self._testPool_2)
|
||||
if not self.early_reconnect:
|
||||
d.addCallback(self._testPool_3)
|
||||
d.addCallback(self._testPool_4)
|
||||
d.addCallback(self._testPool_5)
|
||||
return d
|
||||
|
||||
def _testPool_1(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
def _check(row):
|
||||
self.failUnless(int(row[0][0]) == 0, "Table not empty")
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_2(self, res):
|
||||
# reach in and close the connection manually
|
||||
self.dbpool.connections.values()[0].close()
|
||||
|
||||
def _testPool_3(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_4(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
def _check(row):
|
||||
self.failUnless(int(row[0][0]) == 0, "Table not empty")
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_5(self, res):
|
||||
self.flushLoggedErrors()
|
||||
sql = "select * from NOTABLE" # bad sql
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
|
||||
d.addCallbacks(lambda res: self.fail('no exception'),
|
||||
lambda f: self.failIf(f.check(ConnectionLost)))
|
||||
return d
|
||||
|
||||
|
||||
class DBTestConnector:
|
||||
"""A class which knows how to test for the presence of
|
||||
and establish a connection to a relational database.
|
||||
|
||||
To enable test cases which use a central, system database,
|
||||
you must create a database named DB_NAME with a user DB_USER
|
||||
and password DB_PASS with full access rights to database DB_NAME.
|
||||
"""
|
||||
|
||||
TEST_PREFIX = None # used for creating new test cases
|
||||
|
||||
DB_NAME = "twisted_test"
|
||||
DB_USER = 'twisted_test'
|
||||
DB_PASS = 'twisted_test'
|
||||
|
||||
DB_DIR = None # directory for database storage
|
||||
|
||||
nulls_ok = True # nulls supported
|
||||
trailing_spaces_ok = True # trailing spaces in strings preserved
|
||||
can_rollback = True # rollback supported
|
||||
test_failures = True # test bad sql?
|
||||
escape_slashes = True # escape \ in sql?
|
||||
good_sql = ConnectionPool.good_sql
|
||||
early_reconnect = True # cursor() will fail on closed connection
|
||||
can_clear = True # can try to clear out tables when starting
|
||||
|
||||
num_iterations = 50 # number of iterations for test loops
|
||||
# (lower this for slow db's)
|
||||
|
||||
def setUp(self):
|
||||
self.DB_DIR = self.mktemp()
|
||||
os.mkdir(self.DB_DIR)
|
||||
if not self.can_connect():
|
||||
raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)
|
||||
return self.extraSetUp()
|
||||
|
||||
def can_connect(self):
|
||||
"""Return true if this database is present on the system
|
||||
and can be used in a test."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def startDB(self):
|
||||
"""Take any steps needed to bring database up."""
|
||||
pass
|
||||
|
||||
def stopDB(self):
|
||||
"""Bring database down, if needed."""
|
||||
pass
|
||||
|
||||
def makePool(self, **newkw):
|
||||
"""Create a connection pool with additional keyword arguments."""
|
||||
args, kw = self.getPoolArgs()
|
||||
kw = kw.copy()
|
||||
kw.update(newkw)
|
||||
return ConnectionPool(*args, **kw)
|
||||
|
||||
def getPoolArgs(self):
|
||||
"""Return a tuple (args, kw) of list and keyword arguments
|
||||
that need to be passed to ConnectionPool to create a connection
|
||||
to this database."""
|
||||
raise NotImplementedError()
|
||||
|
||||
class GadflyConnector(DBTestConnector):
|
||||
TEST_PREFIX = 'Gadfly'
|
||||
|
||||
nulls_ok = False
|
||||
can_rollback = False
|
||||
escape_slashes = False
|
||||
good_sql = 'select * from simple where 1=0'
|
||||
|
||||
num_iterations = 1 # slow
|
||||
|
||||
def can_connect(self):
|
||||
try: import gadfly
|
||||
except: return False
|
||||
if not getattr(gadfly, 'connect', None):
|
||||
gadfly.connect = gadfly.gadfly
|
||||
return True
|
||||
|
||||
def startDB(self):
|
||||
import gadfly
|
||||
conn = gadfly.gadfly()
|
||||
conn.startup(self.DB_NAME, self.DB_DIR)
|
||||
|
||||
# gadfly seems to want us to create something to get the db going
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("create table x (x integer)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('gadfly', self.DB_NAME, self.DB_DIR)
|
||||
kw = {'cp_max': 1}
|
||||
return args, kw
|
||||
|
||||
class SQLiteConnector(DBTestConnector):
|
||||
TEST_PREFIX = 'SQLite'
|
||||
|
||||
escape_slashes = False
|
||||
|
||||
num_iterations = 1 # slow
|
||||
|
||||
def can_connect(self):
|
||||
try: import sqlite
|
||||
except: return False
|
||||
return True
|
||||
|
||||
def startDB(self):
|
||||
self.database = os.path.join(self.DB_DIR, self.DB_NAME)
|
||||
if os.path.exists(self.database):
|
||||
os.unlink(self.database)
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('sqlite',)
|
||||
kw = {'database': self.database, 'cp_max': 1}
|
||||
return args, kw
|
||||
|
||||
class PyPgSQLConnector(DBTestConnector):
|
||||
TEST_PREFIX = "PyPgSQL"
|
||||
|
||||
def can_connect(self):
|
||||
try: from pyPgSQL import PgSQL
|
||||
except: return False
|
||||
try:
|
||||
conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
|
||||
password=self.DB_PASS)
|
||||
conn.close()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('pyPgSQL.PgSQL',)
|
||||
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
|
||||
'password': self.DB_PASS, 'cp_min': 0}
|
||||
return args, kw
|
||||
|
||||
class PsycopgConnector(DBTestConnector):
|
||||
TEST_PREFIX = 'Psycopg'
|
||||
|
||||
def can_connect(self):
|
||||
try: import psycopg
|
||||
except: return False
|
||||
try:
|
||||
conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
|
||||
password=self.DB_PASS)
|
||||
conn.close()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('psycopg',)
|
||||
kw = {'database': self.DB_NAME, 'user': self.DB_USER,
|
||||
'password': self.DB_PASS, 'cp_min': 0}
|
||||
return args, kw
|
||||
|
||||
class MySQLConnector(DBTestConnector):
|
||||
TEST_PREFIX = 'MySQL'
|
||||
|
||||
trailing_spaces_ok = False
|
||||
can_rollback = False
|
||||
early_reconnect = False
|
||||
|
||||
def can_connect(self):
|
||||
try: import MySQLdb
|
||||
except: return False
|
||||
try:
|
||||
conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
|
||||
passwd=self.DB_PASS)
|
||||
conn.close()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('MySQLdb',)
|
||||
kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
|
||||
return args, kw
|
||||
|
||||
class FirebirdConnector(DBTestConnector):
|
||||
TEST_PREFIX = 'Firebird'
|
||||
|
||||
test_failures = False # failure testing causes problems
|
||||
escape_slashes = False
|
||||
good_sql = None # firebird doesn't handle failed sql well
|
||||
can_clear = False # firebird is not so good
|
||||
|
||||
num_iterations = 5 # slow
|
||||
|
||||
def can_connect(self):
|
||||
try: import kinterbasdb
|
||||
except: return False
|
||||
try:
|
||||
self.startDB()
|
||||
self.stopDB()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def startDB(self):
|
||||
import kinterbasdb
|
||||
self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
|
||||
os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
|
||||
sql = 'create database "%s" user "%s" password "%s"'
|
||||
sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS);
|
||||
conn = kinterbasdb.create_database(sql)
|
||||
conn.close()
|
||||
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ('kinterbasdb',)
|
||||
kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
|
||||
'user': self.DB_USER, 'password': self.DB_PASS}
|
||||
return args, kw
|
||||
|
||||
def stopDB(self):
|
||||
import kinterbasdb
|
||||
conn = kinterbasdb.connect(database=self.DB_NAME,
|
||||
host='127.0.0.1', user=self.DB_USER,
|
||||
password=self.DB_PASS)
|
||||
conn.drop_database()
|
||||
|
||||
def makeSQLTests(base, suffix, globals):
|
||||
"""
|
||||
Make a test case for every db connector which can connect.
|
||||
|
||||
@param base: Base class for test case. Additional base classes
|
||||
will be a DBConnector subclass and unittest.TestCase
|
||||
@param suffix: A suffix used to create test case names. Prefixes
|
||||
are defined in the DBConnector subclasses.
|
||||
"""
|
||||
connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector,
|
||||
PsycopgConnector, MySQLConnector, FirebirdConnector]
|
||||
for connclass in connectors:
|
||||
name = connclass.TEST_PREFIX + suffix
|
||||
klass = types.ClassType(name, (connclass, base, unittest.TestCase),
|
||||
base.__dict__)
|
||||
globals[name] = klass
|
||||
|
||||
# GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
|
||||
# PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
|
||||
makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())
|
||||
|
||||
# GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
|
||||
# PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
|
||||
makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
|
||||
|
||||
|
||||
|
||||
class FakePool(object):
|
||||
"""
|
||||
A fake L{ConnectionPool} for tests.
|
||||
|
||||
@ivar connectionFactory: factory for making connections returned by the
|
||||
C{connect} method.
|
||||
@type connectionFactory: any callable
|
||||
"""
|
||||
reconnect = True
|
||||
noisy = True
|
||||
|
||||
def __init__(self, connectionFactory):
|
||||
self.connectionFactory = connectionFactory
|
||||
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Return an instance of C{self.connectionFactory}.
|
||||
"""
|
||||
return self.connectionFactory()
|
||||
|
||||
|
||||
def disconnect(self, connection):
|
||||
"""
|
||||
Do nothing.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ConnectionTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for the L{Connection} class.
|
||||
"""
|
||||
|
||||
def test_rollbackErrorLogged(self):
|
||||
"""
|
||||
If an error happens during rollback, L{ConnectionLost} is raised but
|
||||
the original error is logged.
|
||||
"""
|
||||
class ConnectionRollbackRaise(object):
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = FakePool(ConnectionRollbackRaise)
|
||||
connection = Connection(pool)
|
||||
self.assertRaises(ConnectionLost, connection.rollback)
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
|
||||
|
||||
class TransactionTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for the L{Transaction} class.
|
||||
"""
|
||||
|
||||
def test_reopenLogErrorIfReconnect(self):
|
||||
"""
|
||||
If the cursor creation raises an error in L{Transaction.reopen}, it
|
||||
reconnects but log the error occurred.
|
||||
"""
|
||||
class ConnectionCursorRaise(object):
|
||||
count = 0
|
||||
|
||||
def reconnect(self):
|
||||
pass
|
||||
|
||||
def cursor(self):
|
||||
if self.count == 0:
|
||||
self.count += 1
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = FakePool(None)
|
||||
transaction = Transaction(pool, ConnectionCursorRaise())
|
||||
transaction.reopen()
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
|
||||
|
||||
class NonThreadPool(object):
|
||||
def callInThreadWithCallback(self, onResult, f, *a, **kw):
|
||||
success = True
|
||||
try:
|
||||
result = f(*a, **kw)
|
||||
except Exception, e:
|
||||
success = False
|
||||
result = Failure()
|
||||
onResult(success, result)
|
||||
|
||||
|
||||
|
||||
class DummyConnectionPool(ConnectionPool):
|
||||
"""
|
||||
A testable L{ConnectionPool};
|
||||
"""
|
||||
threadpool = NonThreadPool()
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Don't forward init call.
|
||||
"""
|
||||
self.reactor = reactor
|
||||
|
||||
|
||||
|
||||
class EventReactor(object):
|
||||
"""
|
||||
Partial L{IReactorCore} implementation with simple event-related
|
||||
methods.
|
||||
|
||||
@ivar _running: A C{bool} indicating whether the reactor is pretending
|
||||
to have been started already or not.
|
||||
|
||||
@ivar triggers: A C{list} of pending system event triggers.
|
||||
"""
|
||||
def __init__(self, running):
|
||||
self._running = running
|
||||
self.triggers = []
|
||||
|
||||
|
||||
def callWhenRunning(self, function):
|
||||
if self._running:
|
||||
function()
|
||||
else:
|
||||
return self.addSystemEventTrigger('after', 'startup', function)
|
||||
|
||||
|
||||
def addSystemEventTrigger(self, phase, event, trigger):
|
||||
handle = (phase, event, trigger)
|
||||
self.triggers.append(handle)
|
||||
return handle
|
||||
|
||||
|
||||
def removeSystemEventTrigger(self, handle):
|
||||
self.triggers.remove(handle)
|
||||
|
||||
|
||||
|
||||
class ConnectionPoolTestCase(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for L{ConnectionPool}.
|
||||
"""
|
||||
|
||||
def test_runWithConnectionRaiseOriginalError(self):
|
||||
"""
|
||||
If rollback fails, L{ConnectionPool.runWithConnection} raises the
|
||||
original exception and log the error of the rollback.
|
||||
"""
|
||||
class ConnectionRollbackRaise(object):
|
||||
def __init__(self, pool):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
def raisingFunction(connection):
|
||||
raise ValueError("foo")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool.connectionFactory = ConnectionRollbackRaise
|
||||
d = pool.runWithConnection(raisingFunction)
|
||||
d = self.assertFailure(d, ValueError)
|
||||
def cbFailed(ignored):
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
|
||||
def test_closeLogError(self):
|
||||
"""
|
||||
L{ConnectionPool._close} logs exceptions.
|
||||
"""
|
||||
class ConnectionCloseRaise(object):
|
||||
def close(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool._close(ConnectionCloseRaise())
|
||||
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
|
||||
def test_runWithInteractionRaiseOriginalError(self):
|
||||
"""
|
||||
If rollback fails, L{ConnectionPool.runInteraction} raises the
|
||||
original exception and log the error of the rollback.
|
||||
"""
|
||||
class ConnectionRollbackRaise(object):
|
||||
def __init__(self, pool):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
class DummyTransaction(object):
|
||||
def __init__(self, pool, connection):
|
||||
pass
|
||||
|
||||
def raisingFunction(transaction):
|
||||
raise ValueError("foo")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool.connectionFactory = ConnectionRollbackRaise
|
||||
pool.transactionFactory = DummyTransaction
|
||||
|
||||
d = pool.runInteraction(raisingFunction)
|
||||
d = self.assertFailure(d, ValueError)
|
||||
def cbFailed(ignored):
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
|
||||
def test_unstartedClose(self):
|
||||
"""
|
||||
If L{ConnectionPool.close} is called without L{ConnectionPool.start}
|
||||
having been called, the pool's startup event is cancelled.
|
||||
"""
|
||||
reactor = EventReactor(False)
|
||||
pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
|
||||
# There should be a startup trigger waiting.
|
||||
self.assertEqual(reactor.triggers, [('after', 'startup', pool._start)])
|
||||
pool.close()
|
||||
# But not anymore.
|
||||
self.assertFalse(reactor.triggers)
|
||||
|
||||
|
||||
def test_startedClose(self):
|
||||
"""
|
||||
If L{ConnectionPool.close} is called after it has been started, but
|
||||
not by its shutdown trigger, the shutdown trigger is cancelled.
|
||||
"""
|
||||
reactor = EventReactor(True)
|
||||
pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
|
||||
# There should be a shutdown trigger waiting.
|
||||
self.assertEqual(reactor.triggers, [('during', 'shutdown', pool.finalClose)])
|
||||
pool.close()
|
||||
# But not anymore.
|
||||
self.assertFalse(reactor.triggers)
|
||||
3171
Linux_i686/lib/python2.7/site-packages/twisted/test/test_amp.py
Normal file
3171
Linux_i686/lib/python2.7/site-packages/twisted/test/test_amp.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,896 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application} and its interaction with
|
||||
L{twisted.persisted.sob}.
|
||||
"""
|
||||
|
||||
import copy, os, pickle
|
||||
from StringIO import StringIO
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.application import service, internet, app
|
||||
from twisted.persisted import sob
|
||||
from twisted.python import usage
|
||||
from twisted.internet import interfaces, defer
|
||||
from twisted.protocols import wire, basic
|
||||
from twisted.internet import protocol, reactor
|
||||
from twisted.application import reactors
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.python.test.modules_helpers import TwistedModulesMixin
|
||||
|
||||
|
||||
class Dummy:
|
||||
processName=None
|
||||
|
||||
class TestService(unittest.TestCase):
|
||||
|
||||
def testName(self):
|
||||
s = service.Service()
|
||||
s.setName("hello")
|
||||
self.assertEqual(s.name, "hello")
|
||||
|
||||
def testParent(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setServiceParent(p)
|
||||
self.assertEqual(list(p), [s])
|
||||
self.assertEqual(s.parent, p)
|
||||
|
||||
def testApplicationAsParent(self):
|
||||
s = service.Service()
|
||||
p = service.Application("")
|
||||
s.setServiceParent(p)
|
||||
self.assertEqual(list(service.IServiceCollection(p)), [s])
|
||||
self.assertEqual(s.parent, service.IServiceCollection(p))
|
||||
|
||||
def testNamedChild(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setName("hello")
|
||||
s.setServiceParent(p)
|
||||
self.assertEqual(list(p), [s])
|
||||
self.assertEqual(s.parent, p)
|
||||
self.assertEqual(p.getServiceNamed("hello"), s)
|
||||
|
||||
def testDoublyNamedChild(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setName("hello")
|
||||
s.setServiceParent(p)
|
||||
self.failUnlessRaises(RuntimeError, s.setName, "lala")
|
||||
|
||||
def testDuplicateNamedChild(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setName("hello")
|
||||
s.setServiceParent(p)
|
||||
s = service.Service()
|
||||
s.setName("hello")
|
||||
self.failUnlessRaises(RuntimeError, s.setServiceParent, p)
|
||||
|
||||
def testDisowning(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setServiceParent(p)
|
||||
self.assertEqual(list(p), [s])
|
||||
self.assertEqual(s.parent, p)
|
||||
s.disownServiceParent()
|
||||
self.assertEqual(list(p), [])
|
||||
self.assertEqual(s.parent, None)
|
||||
|
||||
def testRunning(self):
|
||||
s = service.Service()
|
||||
self.assert_(not s.running)
|
||||
s.startService()
|
||||
self.assert_(s.running)
|
||||
s.stopService()
|
||||
self.assert_(not s.running)
|
||||
|
||||
def testRunningChildren1(self):
|
||||
s = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setServiceParent(p)
|
||||
self.assert_(not s.running)
|
||||
self.assert_(not p.running)
|
||||
p.startService()
|
||||
self.assert_(s.running)
|
||||
self.assert_(p.running)
|
||||
p.stopService()
|
||||
self.assert_(not s.running)
|
||||
self.assert_(not p.running)
|
||||
|
||||
def testRunningChildren2(self):
|
||||
s = service.Service()
|
||||
def checkRunning():
|
||||
self.assert_(s.running)
|
||||
t = service.Service()
|
||||
t.stopService = checkRunning
|
||||
t.startService = checkRunning
|
||||
p = service.MultiService()
|
||||
s.setServiceParent(p)
|
||||
t.setServiceParent(p)
|
||||
p.startService()
|
||||
p.stopService()
|
||||
|
||||
def testAddingIntoRunning(self):
|
||||
p = service.MultiService()
|
||||
p.startService()
|
||||
s = service.Service()
|
||||
self.assert_(not s.running)
|
||||
s.setServiceParent(p)
|
||||
self.assert_(s.running)
|
||||
s.disownServiceParent()
|
||||
self.assert_(not s.running)
|
||||
|
||||
def testPrivileged(self):
|
||||
s = service.Service()
|
||||
def pss():
|
||||
s.privilegedStarted = 1
|
||||
s.privilegedStartService = pss
|
||||
s1 = service.Service()
|
||||
p = service.MultiService()
|
||||
s.setServiceParent(p)
|
||||
s1.setServiceParent(p)
|
||||
p.privilegedStartService()
|
||||
self.assert_(s.privilegedStarted)
|
||||
|
||||
def testCopying(self):
|
||||
s = service.Service()
|
||||
s.startService()
|
||||
s1 = copy.copy(s)
|
||||
self.assert_(not s1.running)
|
||||
self.assert_(s.running)
|
||||
|
||||
|
||||
if hasattr(os, "getuid"):
|
||||
curuid = os.getuid()
|
||||
curgid = os.getgid()
|
||||
else:
|
||||
curuid = curgid = 0
|
||||
|
||||
|
||||
class TestProcess(unittest.TestCase):
|
||||
|
||||
def testID(self):
|
||||
p = service.Process(5, 6)
|
||||
self.assertEqual(p.uid, 5)
|
||||
self.assertEqual(p.gid, 6)
|
||||
|
||||
def testDefaults(self):
|
||||
p = service.Process(5)
|
||||
self.assertEqual(p.uid, 5)
|
||||
self.assertEqual(p.gid, None)
|
||||
p = service.Process(gid=5)
|
||||
self.assertEqual(p.uid, None)
|
||||
self.assertEqual(p.gid, 5)
|
||||
p = service.Process()
|
||||
self.assertEqual(p.uid, None)
|
||||
self.assertEqual(p.gid, None)
|
||||
|
||||
def testProcessName(self):
|
||||
p = service.Process()
|
||||
self.assertEqual(p.processName, None)
|
||||
p.processName = 'hello'
|
||||
self.assertEqual(p.processName, 'hello')
|
||||
|
||||
|
||||
class TestInterfaces(unittest.TestCase):
|
||||
|
||||
def testService(self):
|
||||
self.assert_(service.IService.providedBy(service.Service()))
|
||||
|
||||
def testMultiService(self):
|
||||
self.assert_(service.IService.providedBy(service.MultiService()))
|
||||
self.assert_(service.IServiceCollection.providedBy(service.MultiService()))
|
||||
|
||||
def testProcess(self):
|
||||
self.assert_(service.IProcess.providedBy(service.Process()))
|
||||
|
||||
|
||||
class TestApplication(unittest.TestCase):
|
||||
|
||||
def testConstructor(self):
|
||||
service.Application("hello")
|
||||
service.Application("hello", 5)
|
||||
service.Application("hello", 5, 6)
|
||||
|
||||
def testProcessComponent(self):
|
||||
a = service.Application("hello")
|
||||
self.assertEqual(service.IProcess(a).uid, None)
|
||||
self.assertEqual(service.IProcess(a).gid, None)
|
||||
a = service.Application("hello", 5)
|
||||
self.assertEqual(service.IProcess(a).uid, 5)
|
||||
self.assertEqual(service.IProcess(a).gid, None)
|
||||
a = service.Application("hello", 5, 6)
|
||||
self.assertEqual(service.IProcess(a).uid, 5)
|
||||
self.assertEqual(service.IProcess(a).gid, 6)
|
||||
|
||||
def testServiceComponent(self):
|
||||
a = service.Application("hello")
|
||||
self.assert_(service.IService(a) is service.IServiceCollection(a))
|
||||
self.assertEqual(service.IService(a).name, "hello")
|
||||
self.assertEqual(service.IService(a).parent, None)
|
||||
|
||||
def testPersistableComponent(self):
|
||||
a = service.Application("hello")
|
||||
p = sob.IPersistable(a)
|
||||
self.assertEqual(p.style, 'pickle')
|
||||
self.assertEqual(p.name, 'hello')
|
||||
self.assert_(p.original is a)
|
||||
|
||||
class TestLoading(unittest.TestCase):
|
||||
|
||||
def test_simpleStoreAndLoad(self):
|
||||
a = service.Application("hello")
|
||||
p = sob.IPersistable(a)
|
||||
for style in 'source pickle'.split():
|
||||
p.setStyle(style)
|
||||
p.save()
|
||||
a1 = service.loadApplication("hello.ta"+style[0], style)
|
||||
self.assertEqual(service.IService(a1).name, "hello")
|
||||
f = open("hello.tac", 'w')
|
||||
f.writelines([
|
||||
"from twisted.application import service\n",
|
||||
"application = service.Application('hello')\n",
|
||||
])
|
||||
f.close()
|
||||
a1 = service.loadApplication("hello.tac", 'python')
|
||||
self.assertEqual(service.IService(a1).name, "hello")
|
||||
|
||||
|
||||
|
||||
class TestAppSupport(unittest.TestCase):
|
||||
|
||||
def testPassphrase(self):
|
||||
self.assertEqual(app.getPassphrase(0), None)
|
||||
|
||||
def testLoadApplication(self):
|
||||
"""
|
||||
Test loading an application file in different dump format.
|
||||
"""
|
||||
a = service.Application("hello")
|
||||
baseconfig = {'file': None, 'source': None, 'python':None}
|
||||
for style in 'source pickle'.split():
|
||||
config = baseconfig.copy()
|
||||
config[{'pickle': 'file'}.get(style, style)] = 'helloapplication'
|
||||
sob.IPersistable(a).setStyle(style)
|
||||
sob.IPersistable(a).save(filename='helloapplication')
|
||||
a1 = app.getApplication(config, None)
|
||||
self.assertEqual(service.IService(a1).name, "hello")
|
||||
config = baseconfig.copy()
|
||||
config['python'] = 'helloapplication'
|
||||
f = open("helloapplication", 'w')
|
||||
f.writelines([
|
||||
"from twisted.application import service\n",
|
||||
"application = service.Application('hello')\n",
|
||||
])
|
||||
f.close()
|
||||
a1 = app.getApplication(config, None)
|
||||
self.assertEqual(service.IService(a1).name, "hello")
|
||||
|
||||
def test_convertStyle(self):
|
||||
appl = service.Application("lala")
|
||||
for instyle in 'source pickle'.split():
|
||||
for outstyle in 'source pickle'.split():
|
||||
sob.IPersistable(appl).setStyle(instyle)
|
||||
sob.IPersistable(appl).save(filename="converttest")
|
||||
app.convertStyle("converttest", instyle, None,
|
||||
"converttest.out", outstyle, 0)
|
||||
appl2 = service.loadApplication("converttest.out", outstyle)
|
||||
self.assertEqual(service.IService(appl2).name, "lala")
|
||||
|
||||
|
||||
def test_startApplication(self):
|
||||
appl = service.Application("lala")
|
||||
app.startApplication(appl, 0)
|
||||
self.assert_(service.IService(appl).running)
|
||||
|
||||
|
||||
class Foo(basic.LineReceiver):
|
||||
def connectionMade(self):
|
||||
self.transport.write('lalala\r\n')
|
||||
def lineReceived(self, line):
|
||||
self.factory.line = line
|
||||
self.transport.loseConnection()
|
||||
def connectionLost(self, reason):
|
||||
self.factory.d.callback(self.factory.line)
|
||||
|
||||
|
||||
class DummyApp:
|
||||
processName = None
|
||||
def addService(self, service):
|
||||
self.services[service.name] = service
|
||||
def removeService(self, service):
|
||||
del self.services[service.name]
|
||||
|
||||
|
||||
class TimerTarget:
|
||||
def __init__(self):
|
||||
self.l = []
|
||||
def append(self, what):
|
||||
self.l.append(what)
|
||||
|
||||
class TestEcho(wire.Echo):
|
||||
def connectionLost(self, reason):
|
||||
self.d.callback(True)
|
||||
|
||||
class TestInternet2(unittest.TestCase):
|
||||
|
||||
def testTCP(self):
|
||||
s = service.MultiService()
|
||||
s.startService()
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = TestEcho
|
||||
TestEcho.d = defer.Deferred()
|
||||
t = internet.TCPServer(0, factory)
|
||||
t.setServiceParent(s)
|
||||
num = t._port.getHost().port
|
||||
factory = protocol.ClientFactory()
|
||||
factory.d = defer.Deferred()
|
||||
factory.protocol = Foo
|
||||
factory.line = None
|
||||
internet.TCPClient('127.0.0.1', num, factory).setServiceParent(s)
|
||||
factory.d.addCallback(self.assertEqual, 'lalala')
|
||||
factory.d.addCallback(lambda x : s.stopService())
|
||||
factory.d.addCallback(lambda x : TestEcho.d)
|
||||
return factory.d
|
||||
|
||||
|
||||
def test_UDP(self):
|
||||
"""
|
||||
Test L{internet.UDPServer} with a random port: starting the service
|
||||
should give it valid port, and stopService should free it so that we
|
||||
can start a server on the same port again.
|
||||
"""
|
||||
if not interfaces.IReactorUDP(reactor, None):
|
||||
raise unittest.SkipTest("This reactor does not support UDP sockets")
|
||||
p = protocol.DatagramProtocol()
|
||||
t = internet.UDPServer(0, p)
|
||||
t.startService()
|
||||
num = t._port.getHost().port
|
||||
self.assertNotEquals(num, 0)
|
||||
def onStop(ignored):
|
||||
t = internet.UDPServer(num, p)
|
||||
t.startService()
|
||||
return t.stopService()
|
||||
return defer.maybeDeferred(t.stopService).addCallback(onStop)
|
||||
|
||||
|
||||
def test_deprecatedUDPClient(self):
|
||||
"""
|
||||
L{internet.UDPClient} is deprecated since Twisted-13.1.
|
||||
"""
|
||||
internet.UDPClient
|
||||
warningsShown = self.flushWarnings([self.test_deprecatedUDPClient])
|
||||
self.assertEqual(1, len(warningsShown))
|
||||
self.assertEqual(
|
||||
"twisted.application.internet.UDPClient was deprecated in "
|
||||
"Twisted 13.1.0: It relies upon IReactorUDP.connectUDP "
|
||||
"which was removed in Twisted 10. "
|
||||
"Use twisted.application.internet.UDPServer instead.",
|
||||
warningsShown[0]['message'])
|
||||
|
||||
|
||||
def testPrivileged(self):
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = TestEcho
|
||||
TestEcho.d = defer.Deferred()
|
||||
t = internet.TCPServer(0, factory)
|
||||
t.privileged = 1
|
||||
t.privilegedStartService()
|
||||
num = t._port.getHost().port
|
||||
factory = protocol.ClientFactory()
|
||||
factory.d = defer.Deferred()
|
||||
factory.protocol = Foo
|
||||
factory.line = None
|
||||
c = internet.TCPClient('127.0.0.1', num, factory)
|
||||
c.startService()
|
||||
factory.d.addCallback(self.assertEqual, 'lalala')
|
||||
factory.d.addCallback(lambda x : c.stopService())
|
||||
factory.d.addCallback(lambda x : t.stopService())
|
||||
factory.d.addCallback(lambda x : TestEcho.d)
|
||||
return factory.d
|
||||
|
||||
def testConnectionGettingRefused(self):
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = wire.Echo
|
||||
t = internet.TCPServer(0, factory)
|
||||
t.startService()
|
||||
num = t._port.getHost().port
|
||||
t.stopService()
|
||||
d = defer.Deferred()
|
||||
factory = protocol.ClientFactory()
|
||||
factory.clientConnectionFailed = lambda *args: d.callback(None)
|
||||
c = internet.TCPClient('127.0.0.1', num, factory)
|
||||
c.startService()
|
||||
return d
|
||||
|
||||
def testUNIX(self):
|
||||
# FIXME: This test is far too dense. It needs comments.
|
||||
# -- spiv, 2004-11-07
|
||||
if not interfaces.IReactorUNIX(reactor, None):
|
||||
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
|
||||
s = service.MultiService()
|
||||
s.startService()
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = TestEcho
|
||||
TestEcho.d = defer.Deferred()
|
||||
t = internet.UNIXServer('echo.skt', factory)
|
||||
t.setServiceParent(s)
|
||||
factory = protocol.ClientFactory()
|
||||
factory.protocol = Foo
|
||||
factory.d = defer.Deferred()
|
||||
factory.line = None
|
||||
internet.UNIXClient('echo.skt', factory).setServiceParent(s)
|
||||
factory.d.addCallback(self.assertEqual, 'lalala')
|
||||
factory.d.addCallback(lambda x : s.stopService())
|
||||
factory.d.addCallback(lambda x : TestEcho.d)
|
||||
factory.d.addCallback(self._cbTestUnix, factory, s)
|
||||
return factory.d
|
||||
|
||||
def _cbTestUnix(self, ignored, factory, s):
|
||||
TestEcho.d = defer.Deferred()
|
||||
factory.line = None
|
||||
factory.d = defer.Deferred()
|
||||
s.startService()
|
||||
factory.d.addCallback(self.assertEqual, 'lalala')
|
||||
factory.d.addCallback(lambda x : s.stopService())
|
||||
factory.d.addCallback(lambda x : TestEcho.d)
|
||||
return factory.d
|
||||
|
||||
def testVolatile(self):
|
||||
if not interfaces.IReactorUNIX(reactor, None):
|
||||
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = wire.Echo
|
||||
t = internet.UNIXServer('echo.skt', factory)
|
||||
t.startService()
|
||||
self.failIfIdentical(t._port, None)
|
||||
t1 = copy.copy(t)
|
||||
self.assertIdentical(t1._port, None)
|
||||
t.stopService()
|
||||
self.assertIdentical(t._port, None)
|
||||
self.failIf(t.running)
|
||||
|
||||
factory = protocol.ClientFactory()
|
||||
factory.protocol = wire.Echo
|
||||
t = internet.UNIXClient('echo.skt', factory)
|
||||
t.startService()
|
||||
self.failIfIdentical(t._connection, None)
|
||||
t1 = copy.copy(t)
|
||||
self.assertIdentical(t1._connection, None)
|
||||
t.stopService()
|
||||
self.assertIdentical(t._connection, None)
|
||||
self.failIf(t.running)
|
||||
|
||||
def testStoppingServer(self):
|
||||
if not interfaces.IReactorUNIX(reactor, None):
|
||||
raise unittest.SkipTest, "This reactor does not support UNIX domain sockets"
|
||||
factory = protocol.ServerFactory()
|
||||
factory.protocol = wire.Echo
|
||||
t = internet.UNIXServer('echo.skt', factory)
|
||||
t.startService()
|
||||
t.stopService()
|
||||
self.failIf(t.running)
|
||||
factory = protocol.ClientFactory()
|
||||
d = defer.Deferred()
|
||||
factory.clientConnectionFailed = lambda *args: d.callback(None)
|
||||
reactor.connectUNIX('echo.skt', factory)
|
||||
return d
|
||||
|
||||
def testPickledTimer(self):
|
||||
target = TimerTarget()
|
||||
t0 = internet.TimerService(1, target.append, "hello")
|
||||
t0.startService()
|
||||
s = pickle.dumps(t0)
|
||||
t0.stopService()
|
||||
|
||||
t = pickle.loads(s)
|
||||
self.failIf(t.running)
|
||||
|
||||
def testBrokenTimer(self):
|
||||
d = defer.Deferred()
|
||||
t = internet.TimerService(1, lambda: 1 // 0)
|
||||
oldFailed = t._failed
|
||||
def _failed(why):
|
||||
oldFailed(why)
|
||||
d.callback(None)
|
||||
t._failed = _failed
|
||||
t.startService()
|
||||
d.addCallback(lambda x : t.stopService)
|
||||
d.addCallback(lambda x : self.assertEqual(
|
||||
[ZeroDivisionError],
|
||||
[o.value.__class__ for o in self.flushLoggedErrors(ZeroDivisionError)]))
|
||||
return d
|
||||
|
||||
|
||||
def test_everythingThere(self):
|
||||
"""
|
||||
L{twisted.application.internet} dynamically defines a set of
|
||||
L{service.Service} subclasses that in general have corresponding
|
||||
reactor.listenXXX or reactor.connectXXX calls.
|
||||
"""
|
||||
trans = 'TCP UNIX SSL UDP UNIXDatagram Multicast'.split()
|
||||
for tran in trans[:]:
|
||||
if not getattr(interfaces, "IReactor" + tran)(reactor, None):
|
||||
trans.remove(tran)
|
||||
for tran in trans:
|
||||
for side in 'Server Client'.split():
|
||||
if tran == "Multicast" and side == "Client":
|
||||
continue
|
||||
self.assertTrue(hasattr(internet, tran + side))
|
||||
method = getattr(internet, tran + side).method
|
||||
prefix = {'Server': 'listen', 'Client': 'connect'}[side]
|
||||
self.assertTrue(hasattr(reactor, prefix + method) or
|
||||
(prefix == "connect" and method == "UDP"))
|
||||
o = getattr(internet, tran + side)()
|
||||
self.assertEqual(service.IService(o), o)
|
||||
|
||||
|
||||
def test_importAll(self):
|
||||
"""
|
||||
L{twisted.application.internet} dynamically defines L{service.Service}
|
||||
subclasses. This test ensures that the subclasses exposed by C{__all__}
|
||||
are valid attributes of the module.
|
||||
"""
|
||||
for cls in internet.__all__:
|
||||
self.assertTrue(
|
||||
hasattr(internet, cls),
|
||||
'%s not importable from twisted.application.internet' % (cls,))
|
||||
|
||||
|
||||
def test_reactorParametrizationInServer(self):
|
||||
"""
|
||||
L{internet._AbstractServer} supports a C{reactor} keyword argument
|
||||
that can be used to parametrize the reactor used to listen for
|
||||
connections.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
|
||||
factory = object()
|
||||
t = internet.TCPServer(1234, factory, reactor=reactor)
|
||||
t.startService()
|
||||
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
|
||||
|
||||
|
||||
def test_reactorParametrizationInClient(self):
|
||||
"""
|
||||
L{internet._AbstractClient} supports a C{reactor} keyword arguments
|
||||
that can be used to parametrize the reactor used to create new client
|
||||
connections.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
|
||||
factory = protocol.ClientFactory()
|
||||
t = internet.TCPClient('127.0.0.1', 1234, factory, reactor=reactor)
|
||||
t.startService()
|
||||
self.assertEqual(
|
||||
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
|
||||
|
||||
|
||||
def test_reactorParametrizationInServerMultipleStart(self):
|
||||
"""
|
||||
Like L{test_reactorParametrizationInServer}, but stop and restart the
|
||||
service and check that the given reactor is still used.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
|
||||
factory = protocol.Factory()
|
||||
t = internet.TCPServer(1234, factory, reactor=reactor)
|
||||
t.startService()
|
||||
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
|
||||
t.stopService()
|
||||
t.startService()
|
||||
self.assertEqual(reactor.tcpServers.pop()[:2], (1234, factory))
|
||||
|
||||
|
||||
def test_reactorParametrizationInClientMultipleStart(self):
|
||||
"""
|
||||
Like L{test_reactorParametrizationInClient}, but stop and restart the
|
||||
service and check that the given reactor is still used.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
|
||||
factory = protocol.ClientFactory()
|
||||
t = internet.TCPClient('127.0.0.1', 1234, factory, reactor=reactor)
|
||||
t.startService()
|
||||
self.assertEqual(
|
||||
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
|
||||
t.stopService()
|
||||
t.startService()
|
||||
self.assertEqual(
|
||||
reactor.tcpClients.pop()[:3], ('127.0.0.1', 1234, factory))
|
||||
|
||||
|
||||
|
||||
class TestTimerBasic(unittest.TestCase):
|
||||
|
||||
def testTimerRuns(self):
|
||||
d = defer.Deferred()
|
||||
self.t = internet.TimerService(1, d.callback, 'hello')
|
||||
self.t.startService()
|
||||
d.addCallback(self.assertEqual, 'hello')
|
||||
d.addCallback(lambda x : self.t.stopService())
|
||||
d.addCallback(lambda x : self.failIf(self.t.running))
|
||||
return d
|
||||
|
||||
def tearDown(self):
|
||||
return self.t.stopService()
|
||||
|
||||
def testTimerRestart(self):
|
||||
# restart the same TimerService
|
||||
d1 = defer.Deferred()
|
||||
d2 = defer.Deferred()
|
||||
work = [(d2, "bar"), (d1, "foo")]
|
||||
def trigger():
|
||||
d, arg = work.pop()
|
||||
d.callback(arg)
|
||||
self.t = internet.TimerService(1, trigger)
|
||||
self.t.startService()
|
||||
def onFirstResult(result):
|
||||
self.assertEqual(result, 'foo')
|
||||
return self.t.stopService()
|
||||
def onFirstStop(ignored):
|
||||
self.failIf(self.t.running)
|
||||
self.t.startService()
|
||||
return d2
|
||||
def onSecondResult(result):
|
||||
self.assertEqual(result, 'bar')
|
||||
self.t.stopService()
|
||||
d1.addCallback(onFirstResult)
|
||||
d1.addCallback(onFirstStop)
|
||||
d1.addCallback(onSecondResult)
|
||||
return d1
|
||||
|
||||
def testTimerLoops(self):
|
||||
l = []
|
||||
def trigger(data, number, d):
|
||||
l.append(data)
|
||||
if len(l) == number:
|
||||
d.callback(l)
|
||||
d = defer.Deferred()
|
||||
self.t = internet.TimerService(0.01, trigger, "hello", 10, d)
|
||||
self.t.startService()
|
||||
d.addCallback(self.assertEqual, ['hello'] * 10)
|
||||
d.addCallback(lambda x : self.t.stopService())
|
||||
return d
|
||||
|
||||
|
||||
class FakeReactor(reactors.Reactor):
|
||||
"""
|
||||
A fake reactor with a hooked install method.
|
||||
"""
|
||||
|
||||
def __init__(self, install, *args, **kwargs):
|
||||
"""
|
||||
@param install: any callable that will be used as install method.
|
||||
@type install: C{callable}
|
||||
"""
|
||||
reactors.Reactor.__init__(self, *args, **kwargs)
|
||||
self.install = install
|
||||
|
||||
|
||||
|
||||
class PluggableReactorTestCase(TwistedModulesMixin, unittest.TestCase):
|
||||
"""
|
||||
Tests for the reactor discovery/inspection APIs.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Override the L{reactors.getPlugins} function, normally bound to
|
||||
L{twisted.plugin.getPlugins}, in order to control which
|
||||
L{IReactorInstaller} plugins are seen as available.
|
||||
|
||||
C{self.pluginResults} can be customized and will be used as the
|
||||
result of calls to C{reactors.getPlugins}.
|
||||
"""
|
||||
self.pluginCalls = []
|
||||
self.pluginResults = []
|
||||
self.originalFunction = reactors.getPlugins
|
||||
reactors.getPlugins = self._getPlugins
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore the original L{reactors.getPlugins}.
|
||||
"""
|
||||
reactors.getPlugins = self.originalFunction
|
||||
|
||||
|
||||
def _getPlugins(self, interface, package=None):
|
||||
"""
|
||||
Stand-in for the real getPlugins method which records its arguments
|
||||
and returns a fixed result.
|
||||
"""
|
||||
self.pluginCalls.append((interface, package))
|
||||
return list(self.pluginResults)
|
||||
|
||||
|
||||
def test_getPluginReactorTypes(self):
|
||||
"""
|
||||
Test that reactor plugins are returned from L{getReactorTypes}
|
||||
"""
|
||||
name = 'fakereactortest'
|
||||
package = __name__ + '.fakereactor'
|
||||
description = 'description'
|
||||
self.pluginResults = [reactors.Reactor(name, package, description)]
|
||||
reactorTypes = reactors.getReactorTypes()
|
||||
|
||||
self.assertEqual(
|
||||
self.pluginCalls,
|
||||
[(reactors.IReactorInstaller, None)])
|
||||
|
||||
for r in reactorTypes:
|
||||
if r.shortName == name:
|
||||
self.assertEqual(r.description, description)
|
||||
break
|
||||
else:
|
||||
self.fail("Reactor plugin not present in getReactorTypes() result")
|
||||
|
||||
|
||||
def test_reactorInstallation(self):
|
||||
"""
|
||||
Test that L{reactors.Reactor.install} loads the correct module and
|
||||
calls its install attribute.
|
||||
"""
|
||||
installed = []
|
||||
def install():
|
||||
installed.append(True)
|
||||
fakeReactor = FakeReactor(install,
|
||||
'fakereactortest', __name__, 'described')
|
||||
modules = {'fakereactortest': fakeReactor}
|
||||
self.replaceSysModules(modules)
|
||||
installer = reactors.Reactor('fakereactor', 'fakereactortest', 'described')
|
||||
installer.install()
|
||||
self.assertEqual(installed, [True])
|
||||
|
||||
|
||||
def test_installReactor(self):
|
||||
"""
|
||||
Test that the L{reactors.installReactor} function correctly installs
|
||||
the specified reactor.
|
||||
"""
|
||||
installed = []
|
||||
def install():
|
||||
installed.append(True)
|
||||
name = 'fakereactortest'
|
||||
package = __name__
|
||||
description = 'description'
|
||||
self.pluginResults = [FakeReactor(install, name, package, description)]
|
||||
reactors.installReactor(name)
|
||||
self.assertEqual(installed, [True])
|
||||
|
||||
|
||||
def test_installReactorReturnsReactor(self):
|
||||
"""
|
||||
Test that the L{reactors.installReactor} function correctly returns
|
||||
the installed reactor.
|
||||
"""
|
||||
reactor = object()
|
||||
def install():
|
||||
from twisted import internet
|
||||
self.patch(internet, 'reactor', reactor)
|
||||
name = 'fakereactortest'
|
||||
package = __name__
|
||||
description = 'description'
|
||||
self.pluginResults = [FakeReactor(install, name, package, description)]
|
||||
installed = reactors.installReactor(name)
|
||||
self.assertIdentical(installed, reactor)
|
||||
|
||||
|
||||
def test_installReactorMultiplePlugins(self):
|
||||
"""
|
||||
Test that the L{reactors.installReactor} function correctly installs
|
||||
the specified reactor when there are multiple reactor plugins.
|
||||
"""
|
||||
installed = []
|
||||
def install():
|
||||
installed.append(True)
|
||||
name = 'fakereactortest'
|
||||
package = __name__
|
||||
description = 'description'
|
||||
fakeReactor = FakeReactor(install, name, package, description)
|
||||
otherReactor = FakeReactor(lambda: None,
|
||||
"otherreactor", package, description)
|
||||
self.pluginResults = [otherReactor, fakeReactor]
|
||||
reactors.installReactor(name)
|
||||
self.assertEqual(installed, [True])
|
||||
|
||||
|
||||
def test_installNonExistentReactor(self):
|
||||
"""
|
||||
Test that L{reactors.installReactor} raises L{reactors.NoSuchReactor}
|
||||
when asked to install a reactor which it cannot find.
|
||||
"""
|
||||
self.pluginResults = []
|
||||
self.assertRaises(
|
||||
reactors.NoSuchReactor,
|
||||
reactors.installReactor, 'somereactor')
|
||||
|
||||
|
||||
def test_installNotAvailableReactor(self):
|
||||
"""
|
||||
Test that L{reactors.installReactor} raises an exception when asked to
|
||||
install a reactor which doesn't work in this environment.
|
||||
"""
|
||||
def install():
|
||||
raise ImportError("Missing foo bar")
|
||||
name = 'fakereactortest'
|
||||
package = __name__
|
||||
description = 'description'
|
||||
self.pluginResults = [FakeReactor(install, name, package, description)]
|
||||
self.assertRaises(ImportError, reactors.installReactor, name)
|
||||
|
||||
|
||||
def test_reactorSelectionMixin(self):
|
||||
"""
|
||||
Test that the reactor selected is installed as soon as possible, ie
|
||||
when the option is parsed.
|
||||
"""
|
||||
executed = []
|
||||
INSTALL_EVENT = 'reactor installed'
|
||||
SUBCOMMAND_EVENT = 'subcommands loaded'
|
||||
|
||||
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
|
||||
def subCommands(self):
|
||||
executed.append(SUBCOMMAND_EVENT)
|
||||
return [('subcommand', None, lambda: self, 'test subcommand')]
|
||||
subCommands = property(subCommands)
|
||||
|
||||
def install():
|
||||
executed.append(INSTALL_EVENT)
|
||||
self.pluginResults = [
|
||||
FakeReactor(install, 'fakereactortest', __name__, 'described')
|
||||
]
|
||||
|
||||
options = ReactorSelectionOptions()
|
||||
options.parseOptions(['--reactor', 'fakereactortest', 'subcommand'])
|
||||
self.assertEqual(executed[0], INSTALL_EVENT)
|
||||
self.assertEqual(executed.count(INSTALL_EVENT), 1)
|
||||
self.assertEqual(options["reactor"], "fakereactortest")
|
||||
|
||||
|
||||
def test_reactorSelectionMixinNonExistent(self):
|
||||
"""
|
||||
Test that the usage mixin exits when trying to use a non existent
|
||||
reactor (the name not matching to any reactor), giving an error
|
||||
message.
|
||||
"""
|
||||
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
|
||||
pass
|
||||
self.pluginResults = []
|
||||
|
||||
options = ReactorSelectionOptions()
|
||||
options.messageOutput = StringIO()
|
||||
e = self.assertRaises(usage.UsageError, options.parseOptions,
|
||||
['--reactor', 'fakereactortest', 'subcommand'])
|
||||
self.assertIn("fakereactortest", e.args[0])
|
||||
self.assertIn("help-reactors", e.args[0])
|
||||
|
||||
|
||||
def test_reactorSelectionMixinNotAvailable(self):
|
||||
"""
|
||||
Test that the usage mixin exits when trying to use a reactor not
|
||||
available (the reactor raises an error at installation), giving an
|
||||
error message.
|
||||
"""
|
||||
class ReactorSelectionOptions(usage.Options, app.ReactorSelectionMixin):
|
||||
pass
|
||||
message = "Missing foo bar"
|
||||
def install():
|
||||
raise ImportError(message)
|
||||
|
||||
name = 'fakereactortest'
|
||||
package = __name__
|
||||
description = 'description'
|
||||
self.pluginResults = [FakeReactor(install, name, package, description)]
|
||||
|
||||
options = ReactorSelectionOptions()
|
||||
options.messageOutput = StringIO()
|
||||
e = self.assertRaises(usage.UsageError, options.parseOptions,
|
||||
['--reactor', 'fakereactortest', 'subcommand'])
|
||||
self.assertIn(message, e.args[0])
|
||||
self.assertIn("help-reactors", e.args[0])
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import StringIO
|
||||
import sys
|
||||
|
||||
# Twisted Imports
|
||||
from twisted.trial import unittest
|
||||
from twisted.spread import banana
|
||||
from twisted.python import failure
|
||||
from twisted.internet import protocol, main
|
||||
|
||||
|
||||
class MathTestCase(unittest.TestCase):
|
||||
def testInt2b128(self):
|
||||
funkylist = range(0,100) + range(1000,1100) + range(1000000,1000100) + [1024 **10l]
|
||||
for i in funkylist:
|
||||
x = StringIO.StringIO()
|
||||
banana.int2b128(i, x.write)
|
||||
v = x.getvalue()
|
||||
y = banana.b1282int(v)
|
||||
assert y == i, "y = %s; i = %s" % (y,i)
|
||||
|
||||
class BananaTestCase(unittest.TestCase):
|
||||
|
||||
encClass = banana.Banana
|
||||
|
||||
def setUp(self):
|
||||
self.io = StringIO.StringIO()
|
||||
self.enc = self.encClass()
|
||||
self.enc.makeConnection(protocol.FileWrapper(self.io))
|
||||
self.enc._selectDialect("none")
|
||||
self.enc.expressionReceived = self.putResult
|
||||
|
||||
def putResult(self, result):
|
||||
self.result = result
|
||||
|
||||
def tearDown(self):
|
||||
self.enc.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
del self.enc
|
||||
|
||||
def testString(self):
|
||||
self.enc.sendEncoded("hello")
|
||||
l = []
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == 'hello'
|
||||
|
||||
def test_int(self):
|
||||
"""
|
||||
A positive integer less than 2 ** 32 should round-trip through
|
||||
banana without changing value and should come out represented
|
||||
as an C{int} (regardless of the type which was encoded).
|
||||
"""
|
||||
for value in (10151, 10151L):
|
||||
self.enc.sendEncoded(value)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
self.assertEqual(self.result, 10151)
|
||||
self.assertIsInstance(self.result, int)
|
||||
|
||||
|
||||
def test_largeLong(self):
|
||||
"""
|
||||
Integers greater than 2 ** 32 and less than -2 ** 32 should
|
||||
round-trip through banana without changing value and should
|
||||
come out represented as C{int} instances if the value fits
|
||||
into that type on the receiving platform.
|
||||
"""
|
||||
for exp in (32, 64, 128, 256):
|
||||
for add in (0, 1):
|
||||
m = 2 ** exp + add
|
||||
for n in (m, -m-1):
|
||||
self.io.truncate(0)
|
||||
self.enc.sendEncoded(n)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
self.assertEqual(self.result, n)
|
||||
if n > sys.maxint or n < -sys.maxint - 1:
|
||||
self.assertIsInstance(self.result, long)
|
||||
else:
|
||||
self.assertIsInstance(self.result, int)
|
||||
|
||||
|
||||
def _getSmallest(self):
|
||||
# How many bytes of prefix our implementation allows
|
||||
bytes = self.enc.prefixLimit
|
||||
# How many useful bits we can extract from that based on Banana's
|
||||
# base-128 representation.
|
||||
bits = bytes * 7
|
||||
# The largest number we _should_ be able to encode
|
||||
largest = 2 ** bits - 1
|
||||
# The smallest number we _shouldn't_ be able to encode
|
||||
smallest = largest + 1
|
||||
return smallest
|
||||
|
||||
|
||||
def test_encodeTooLargeLong(self):
|
||||
"""
|
||||
Test that a long above the implementation-specific limit is rejected
|
||||
as too large to be encoded.
|
||||
"""
|
||||
smallest = self._getSmallest()
|
||||
self.assertRaises(banana.BananaError, self.enc.sendEncoded, smallest)
|
||||
|
||||
|
||||
def test_decodeTooLargeLong(self):
|
||||
"""
|
||||
Test that a long above the implementation specific limit is rejected
|
||||
as too large to be decoded.
|
||||
"""
|
||||
smallest = self._getSmallest()
|
||||
self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
|
||||
self.enc.sendEncoded(smallest)
|
||||
encoded = self.io.getvalue()
|
||||
self.io.truncate(0)
|
||||
self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
|
||||
|
||||
self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)
|
||||
|
||||
|
||||
def _getLargest(self):
|
||||
return -self._getSmallest()
|
||||
|
||||
|
||||
def test_encodeTooSmallLong(self):
|
||||
"""
|
||||
Test that a negative long below the implementation-specific limit is
|
||||
rejected as too small to be encoded.
|
||||
"""
|
||||
largest = self._getLargest()
|
||||
self.assertRaises(banana.BananaError, self.enc.sendEncoded, largest)
|
||||
|
||||
|
||||
def test_decodeTooSmallLong(self):
|
||||
"""
|
||||
Test that a negative long below the implementation specific limit is
|
||||
rejected as too small to be decoded.
|
||||
"""
|
||||
largest = self._getLargest()
|
||||
self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
|
||||
self.enc.sendEncoded(largest)
|
||||
encoded = self.io.getvalue()
|
||||
self.io.truncate(0)
|
||||
self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
|
||||
|
||||
self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)
|
||||
|
||||
|
||||
def testNegativeLong(self):
|
||||
self.enc.sendEncoded(-1015l)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == -1015l, "should be -1015l, got %s" % self.result
|
||||
|
||||
def testInteger(self):
|
||||
self.enc.sendEncoded(1015)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == 1015, "should be 1015, got %s" % self.result
|
||||
|
||||
def testNegative(self):
|
||||
self.enc.sendEncoded(-1015)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == -1015, "should be -1015, got %s" % self.result
|
||||
|
||||
def testFloat(self):
|
||||
self.enc.sendEncoded(1015.)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == 1015.
|
||||
|
||||
def testList(self):
|
||||
foo = [1, 2, [3, 4], [30.5, 40.2], 5, ["six", "seven", ["eight", 9]], [10], []]
|
||||
self.enc.sendEncoded(foo)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == foo, "%s!=%s" % (repr(self.result), repr(self.result))
|
||||
|
||||
def testPartial(self):
|
||||
foo = [1, 2, [3, 4], [30.5, 40.2], 5,
|
||||
["six", "seven", ["eight", 9]], [10],
|
||||
# TODO: currently the C implementation's a bit buggy...
|
||||
sys.maxint * 3l, sys.maxint * 2l, sys.maxint * -2l]
|
||||
self.enc.sendEncoded(foo)
|
||||
for byte in self.io.getvalue():
|
||||
self.enc.dataReceived(byte)
|
||||
assert self.result == foo, "%s!=%s" % (repr(self.result), repr(foo))
|
||||
|
||||
def feed(self, data):
|
||||
for byte in data:
|
||||
self.enc.dataReceived(byte)
|
||||
def testOversizedList(self):
|
||||
data = '\x02\x01\x01\x01\x01\x80'
|
||||
# list(size=0x0101010102, about 4.3e9)
|
||||
self.failUnlessRaises(banana.BananaError, self.feed, data)
|
||||
def testOversizedString(self):
|
||||
data = '\x02\x01\x01\x01\x01\x82'
|
||||
# string(size=0x0101010102, about 4.3e9)
|
||||
self.failUnlessRaises(banana.BananaError, self.feed, data)
|
||||
|
||||
def testCrashString(self):
|
||||
crashString = '\x00\x00\x00\x00\x04\x80'
|
||||
# string(size=0x0400000000, about 17.2e9)
|
||||
|
||||
# cBanana would fold that into a 32-bit 'int', then try to allocate
|
||||
# a list with PyList_New(). cBanana ignored the NULL return value,
|
||||
# so it would segfault when trying to free the imaginary list.
|
||||
|
||||
# This variant doesn't segfault straight out in my environment.
|
||||
# Instead, it takes up large amounts of CPU and memory...
|
||||
#crashString = '\x00\x00\x00\x00\x01\x80'
|
||||
# print repr(crashString)
|
||||
#self.failUnlessRaises(Exception, self.enc.dataReceived, crashString)
|
||||
try:
|
||||
# should now raise MemoryError
|
||||
self.enc.dataReceived(crashString)
|
||||
except banana.BananaError:
|
||||
pass
|
||||
|
||||
def testCrashNegativeLong(self):
|
||||
# There was a bug in cBanana which relied on negating a negative integer
|
||||
# always giving a postive result, but for the lowest possible number in
|
||||
# 2s-complement arithmetic, that's not true, i.e.
|
||||
# long x = -2147483648;
|
||||
# long y = -x;
|
||||
# x == y; /* true! */
|
||||
# (assuming 32-bit longs)
|
||||
self.enc.sendEncoded(-2147483648)
|
||||
self.enc.dataReceived(self.io.getvalue())
|
||||
assert self.result == -2147483648, "should be -2147483648, got %s" % self.result
|
||||
|
||||
|
||||
def test_sizedIntegerTypes(self):
|
||||
"""
|
||||
Test that integers below the maximum C{INT} token size cutoff are
|
||||
serialized as C{INT} or C{NEG} and that larger integers are
|
||||
serialized as C{LONGINT} or C{LONGNEG}.
|
||||
"""
|
||||
def encoded(n):
|
||||
self.io.seek(0)
|
||||
self.io.truncate()
|
||||
self.enc.sendEncoded(n)
|
||||
return self.io.getvalue()
|
||||
|
||||
baseIntIn = +2147483647
|
||||
baseNegIn = -2147483648
|
||||
|
||||
baseIntOut = '\x7f\x7f\x7f\x07\x81'
|
||||
self.assertEqual(encoded(baseIntIn - 2), '\x7d' + baseIntOut)
|
||||
self.assertEqual(encoded(baseIntIn - 1), '\x7e' + baseIntOut)
|
||||
self.assertEqual(encoded(baseIntIn - 0), '\x7f' + baseIntOut)
|
||||
|
||||
baseLongIntOut = '\x00\x00\x00\x08\x85'
|
||||
self.assertEqual(encoded(baseIntIn + 1), '\x00' + baseLongIntOut)
|
||||
self.assertEqual(encoded(baseIntIn + 2), '\x01' + baseLongIntOut)
|
||||
self.assertEqual(encoded(baseIntIn + 3), '\x02' + baseLongIntOut)
|
||||
|
||||
baseNegOut = '\x7f\x7f\x7f\x07\x83'
|
||||
self.assertEqual(encoded(baseNegIn + 2), '\x7e' + baseNegOut)
|
||||
self.assertEqual(encoded(baseNegIn + 1), '\x7f' + baseNegOut)
|
||||
self.assertEqual(encoded(baseNegIn + 0), '\x00\x00\x00\x00\x08\x83')
|
||||
|
||||
baseLongNegOut = '\x00\x00\x00\x08\x86'
|
||||
self.assertEqual(encoded(baseNegIn - 1), '\x01' + baseLongNegOut)
|
||||
self.assertEqual(encoded(baseNegIn - 2), '\x02' + baseLongNegOut)
|
||||
self.assertEqual(encoded(baseNegIn - 3), '\x03' + baseLongNegOut)
|
||||
|
||||
|
||||
|
||||
class GlobalCoderTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the free functions L{banana.encode} and L{banana.decode}.
|
||||
"""
|
||||
def test_statelessDecode(self):
|
||||
"""
|
||||
Test that state doesn't carry over between calls to L{banana.decode}.
|
||||
"""
|
||||
# Banana encoding of 2 ** 449
|
||||
undecodable = '\x7f' * 65 + '\x85'
|
||||
self.assertRaises(banana.BananaError, banana.decode, undecodable)
|
||||
|
||||
# Banana encoding of 1
|
||||
decodable = '\x01\x81'
|
||||
self.assertEqual(banana.decode(decodable), 1)
|
||||
|
|
@ -0,0 +1,623 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.compat}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import socket, sys, traceback
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.python.compat import reduce, execfile, _PY3
|
||||
from twisted.python.compat import comparable, cmp, nativeString, networkString
|
||||
from twisted.python.compat import unicode as unicodeCompat, lazyByteSlice
|
||||
from twisted.python.compat import reraise, NativeStringIO, iterbytes, intToBytes
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
|
||||
class CompatTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Various utility functions in C{twisted.python.compat} provide same
|
||||
functionality as modern Python variants.
|
||||
"""
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
L{set} should behave like the expected set interface.
|
||||
"""
|
||||
a = set()
|
||||
a.add('b')
|
||||
a.add('c')
|
||||
a.add('a')
|
||||
b = list(a)
|
||||
b.sort()
|
||||
self.assertEqual(b, ['a', 'b', 'c'])
|
||||
a.remove('b')
|
||||
b = list(a)
|
||||
b.sort()
|
||||
self.assertEqual(b, ['a', 'c'])
|
||||
|
||||
a.discard('d')
|
||||
|
||||
b = set(['r', 's'])
|
||||
d = a.union(b)
|
||||
b = list(d)
|
||||
b.sort()
|
||||
self.assertEqual(b, ['a', 'c', 'r', 's'])
|
||||
|
||||
|
||||
def test_frozenset(self):
|
||||
"""
|
||||
L{frozenset} should behave like the expected frozenset interface.
|
||||
"""
|
||||
a = frozenset(['a', 'b'])
|
||||
self.assertRaises(AttributeError, getattr, a, "add")
|
||||
self.assertEqual(sorted(a), ['a', 'b'])
|
||||
|
||||
b = frozenset(['r', 's'])
|
||||
d = a.union(b)
|
||||
b = list(d)
|
||||
b.sort()
|
||||
self.assertEqual(b, ['a', 'b', 'r', 's'])
|
||||
|
||||
|
||||
def test_reduce(self):
|
||||
"""
|
||||
L{reduce} should behave like the builtin reduce.
|
||||
"""
|
||||
self.assertEqual(15, reduce(lambda x, y: x + y, [1, 2, 3, 4, 5]))
|
||||
self.assertEqual(16, reduce(lambda x, y: x + y, [1, 2, 3, 4, 5], 1))
|
||||
|
||||
|
||||
|
||||
class IPv6Tests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
C{inet_pton} and C{inet_ntop} implementations support IPv6.
|
||||
"""
|
||||
|
||||
def testNToP(self):
|
||||
from twisted.python.compat import inet_ntop
|
||||
|
||||
f = lambda a: inet_ntop(socket.AF_INET6, a)
|
||||
g = lambda a: inet_ntop(socket.AF_INET, a)
|
||||
|
||||
self.assertEqual('::', f('\x00' * 16))
|
||||
self.assertEqual('::1', f('\x00' * 15 + '\x01'))
|
||||
self.assertEqual(
|
||||
'aef:b01:506:1001:ffff:9997:55:170',
|
||||
f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70'))
|
||||
|
||||
self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
|
||||
self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
|
||||
self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
|
||||
|
||||
self.assertEqual('100::', f('\x01' + '\x00' * 15))
|
||||
self.assertEqual('100::1', f('\x01' + '\x00' * 14 + '\x01'))
|
||||
|
||||
def testPToN(self):
|
||||
from twisted.python.compat import inet_pton
|
||||
|
||||
f = lambda a: inet_pton(socket.AF_INET6, a)
|
||||
g = lambda a: inet_pton(socket.AF_INET, a)
|
||||
|
||||
self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
|
||||
self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
|
||||
self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
|
||||
|
||||
self.assertEqual('\x00' * 16, f('::'))
|
||||
self.assertEqual('\x00' * 16, f('0::0'))
|
||||
self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
|
||||
self.assertEqual(
|
||||
'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
|
||||
f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae'))
|
||||
|
||||
self.assertEqual('\x00' * 14 + '\x00\x01', f('::1'))
|
||||
self.assertEqual('\x00' * 12 + '\x01\x02\x03\x04', f('::1.2.3.4'))
|
||||
self.assertEqual(
|
||||
'\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00\x06\x01\x02\x03\xff',
|
||||
f('1:2:3:4:5:6:1.2.3.255'))
|
||||
|
||||
for badaddr in ['1:2:3:4:5:6:7:8:', ':1:2:3:4:5:6:7:8', '1::2::3',
|
||||
'1:::3', ':::', '1:2', '::1.2', '1.2.3.4::',
|
||||
'abcd:1.2.3.4:abcd:abcd:abcd:abcd:abcd',
|
||||
'1234:1.2.3.4:1234:1234:1234:1234:1234:1234',
|
||||
'1.2.3.4']:
|
||||
self.assertRaises(ValueError, f, badaddr)
|
||||
|
||||
if _PY3:
|
||||
IPv6Tests.skip = "These tests are only relevant to old versions of Python"
|
||||
|
||||
|
||||
|
||||
class ExecfileCompatTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for the Python 3-friendly L{execfile} implementation.
|
||||
"""
|
||||
|
||||
def writeScript(self, content):
|
||||
"""
|
||||
Write L{content} to a new temporary file, returning the L{FilePath}
|
||||
for the new file.
|
||||
"""
|
||||
path = self.mktemp()
|
||||
with open(path, "wb") as f:
|
||||
f.write(content.encode("ascii"))
|
||||
return FilePath(path.encode("utf-8"))
|
||||
|
||||
|
||||
def test_execfileGlobals(self):
|
||||
"""
|
||||
L{execfile} executes the specified file in the given global namespace.
|
||||
"""
|
||||
script = self.writeScript(u"foo += 1\n")
|
||||
globalNamespace = {"foo": 1}
|
||||
execfile(script.path, globalNamespace)
|
||||
self.assertEqual(2, globalNamespace["foo"])
|
||||
|
||||
|
||||
def test_execfileGlobalsAndLocals(self):
|
||||
"""
|
||||
L{execfile} executes the specified file in the given global and local
|
||||
namespaces.
|
||||
"""
|
||||
script = self.writeScript(u"foo += 1\n")
|
||||
globalNamespace = {"foo": 10}
|
||||
localNamespace = {"foo": 20}
|
||||
execfile(script.path, globalNamespace, localNamespace)
|
||||
self.assertEqual(10, globalNamespace["foo"])
|
||||
self.assertEqual(21, localNamespace["foo"])
|
||||
|
||||
|
||||
def test_execfileUniversalNewlines(self):
|
||||
"""
|
||||
L{execfile} reads in the specified file using universal newlines so
|
||||
that scripts written on one platform will work on another.
|
||||
"""
|
||||
for lineEnding in u"\n", u"\r", u"\r\n":
|
||||
script = self.writeScript(u"foo = 'okay'" + lineEnding)
|
||||
globalNamespace = {"foo": None}
|
||||
execfile(script.path, globalNamespace)
|
||||
self.assertEqual("okay", globalNamespace["foo"])
|
||||
|
||||
|
||||
|
||||
class PY3Tests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Identification of Python 2 vs. Python 3.
|
||||
"""
|
||||
|
||||
def test_python2(self):
|
||||
"""
|
||||
On Python 2, C{_PY3} is False.
|
||||
"""
|
||||
if sys.version.startswith("2."):
|
||||
self.assertFalse(_PY3)
|
||||
|
||||
|
||||
def test_python3(self):
|
||||
"""
|
||||
On Python 3, C{_PY3} is True.
|
||||
"""
|
||||
if sys.version.startswith("3."):
|
||||
self.assertTrue(_PY3)
|
||||
|
||||
|
||||
|
||||
@comparable
|
||||
class Comparable(object):
|
||||
"""
|
||||
Objects that can be compared to each other, but not others.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
def __cmp__(self, other):
|
||||
if not isinstance(other, Comparable):
|
||||
return NotImplemented
|
||||
return cmp(self.value, other.value)
|
||||
|
||||
|
||||
|
||||
class ComparableTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
L{comparable} decorated classes emulate Python 2's C{__cmp__} semantics.
|
||||
"""
|
||||
|
||||
def test_equality(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
equality comparisons.
|
||||
"""
|
||||
# Make explicitly sure we're using ==:
|
||||
self.assertTrue(Comparable(1) == Comparable(1))
|
||||
self.assertFalse(Comparable(2) == Comparable(1))
|
||||
|
||||
|
||||
def test_nonEquality(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
inequality comparisons.
|
||||
"""
|
||||
# Make explicitly sure we're using !=:
|
||||
self.assertFalse(Comparable(1) != Comparable(1))
|
||||
self.assertTrue(Comparable(2) != Comparable(1))
|
||||
|
||||
|
||||
def test_greaterThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
greater-than comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(2) > Comparable(1))
|
||||
self.assertFalse(Comparable(0) > Comparable(3))
|
||||
|
||||
|
||||
def test_greaterThanOrEqual(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
greater-than-or-equal comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(1) >= Comparable(1))
|
||||
self.assertTrue(Comparable(2) >= Comparable(1))
|
||||
self.assertFalse(Comparable(0) >= Comparable(3))
|
||||
|
||||
|
||||
def test_lessThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
less-than comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(0) < Comparable(3))
|
||||
self.assertFalse(Comparable(2) < Comparable(0))
|
||||
|
||||
|
||||
def test_lessThanOrEqual(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
less-than-or-equal comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(3) <= Comparable(3))
|
||||
self.assertTrue(Comparable(0) <= Comparable(3))
|
||||
self.assertFalse(Comparable(2) <= Comparable(0))
|
||||
|
||||
|
||||
|
||||
class Python3ComparableTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Python 3-specific functionality of C{comparable}.
|
||||
"""
|
||||
|
||||
def test_notImplementedEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__eq__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__eq__(object()), NotImplemented)
|
||||
|
||||
|
||||
def test_notImplementedNotEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__ne__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__ne__(object()), NotImplemented)
|
||||
|
||||
|
||||
def test_notImplementedGreaterThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__gt__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__gt__(object()), NotImplemented)
|
||||
|
||||
|
||||
def test_notImplementedLessThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__lt__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__lt__(object()), NotImplemented)
|
||||
|
||||
|
||||
def test_notImplementedGreaterThanEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__ge__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__ge__(object()), NotImplemented)
|
||||
|
||||
|
||||
def test_notImplementedLessThanEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__le__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__le__(object()), NotImplemented)
|
||||
|
||||
if not _PY3:
|
||||
# On Python 2, we just use __cmp__ directly, so checking detailed
|
||||
# comparison methods doesn't makes sense.
|
||||
Python3ComparableTests.skip = "Python 3 only."
|
||||
|
||||
|
||||
|
||||
class CmpTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
L{cmp} should behave like the built-in Python 2 C{cmp}.
|
||||
"""
|
||||
|
||||
def test_equals(self):
|
||||
"""
|
||||
L{cmp} returns 0 for equal objects.
|
||||
"""
|
||||
self.assertEqual(cmp(u"a", u"a"), 0)
|
||||
self.assertEqual(cmp(1, 1), 0)
|
||||
self.assertEqual(cmp([1], [1]), 0)
|
||||
|
||||
|
||||
def test_greaterThan(self):
|
||||
"""
|
||||
L{cmp} returns 1 if its first argument is bigger than its second.
|
||||
"""
|
||||
self.assertEqual(cmp(4, 0), 1)
|
||||
self.assertEqual(cmp(b"z", b"a"), 1)
|
||||
|
||||
|
||||
def test_lessThan(self):
|
||||
"""
|
||||
L{cmp} returns -1 if its first argument is smaller than its second.
|
||||
"""
|
||||
self.assertEqual(cmp(0.1, 2.3), -1)
|
||||
self.assertEqual(cmp(b"a", b"d"), -1)
|
||||
|
||||
|
||||
|
||||
class StringTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Compatibility functions and types for strings.
|
||||
"""
|
||||
|
||||
def assertNativeString(self, original, expected):
|
||||
"""
|
||||
Raise an exception indicating a failed test if the output of
|
||||
C{nativeString(original)} is unequal to the expected string, or is not
|
||||
a native string.
|
||||
"""
|
||||
self.assertEqual(nativeString(original), expected)
|
||||
self.assertIsInstance(nativeString(original), str)
|
||||
|
||||
|
||||
def test_nonASCIIBytesToString(self):
|
||||
"""
|
||||
C{nativeString} raises a C{UnicodeError} if input bytes are not ASCII
|
||||
decodable.
|
||||
"""
|
||||
self.assertRaises(UnicodeError, nativeString, b"\xFF")
|
||||
|
||||
|
||||
def test_nonASCIIUnicodeToString(self):
|
||||
"""
|
||||
C{nativeString} raises a C{UnicodeError} if input Unicode is not ASCII
|
||||
encodable.
|
||||
"""
|
||||
self.assertRaises(UnicodeError, nativeString, u"\u1234")
|
||||
|
||||
|
||||
def test_bytesToString(self):
|
||||
"""
|
||||
C{nativeString} converts bytes to the native string format, assuming
|
||||
an ASCII encoding if applicable.
|
||||
"""
|
||||
self.assertNativeString(b"hello", "hello")
|
||||
|
||||
|
||||
def test_unicodeToString(self):
|
||||
"""
|
||||
C{nativeString} converts unicode to the native string format, assuming
|
||||
an ASCII encoding if applicable.
|
||||
"""
|
||||
self.assertNativeString(u"Good day", "Good day")
|
||||
|
||||
|
||||
def test_stringToString(self):
|
||||
"""
|
||||
C{nativeString} leaves native strings as native strings.
|
||||
"""
|
||||
self.assertNativeString("Hello!", "Hello!")
|
||||
|
||||
|
||||
def test_unexpectedType(self):
|
||||
"""
|
||||
C{nativeString} raises a C{TypeError} if given an object that is not a
|
||||
string of some sort.
|
||||
"""
|
||||
self.assertRaises(TypeError, nativeString, 1)
|
||||
|
||||
|
||||
def test_unicode(self):
|
||||
"""
|
||||
C{compat.unicode} is C{str} on Python 3, C{unicode} on Python 2.
|
||||
"""
|
||||
if _PY3:
|
||||
expected = str
|
||||
else:
|
||||
expected = unicode
|
||||
self.assertTrue(unicodeCompat is expected)
|
||||
|
||||
|
||||
def test_nativeStringIO(self):
|
||||
"""
|
||||
L{NativeStringIO} is a file-like object that stores native strings in
|
||||
memory.
|
||||
"""
|
||||
f = NativeStringIO()
|
||||
f.write("hello")
|
||||
f.write(" there")
|
||||
self.assertEqual(f.getvalue(), "hello there")
|
||||
|
||||
|
||||
|
||||
class NetworkStringTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{networkString}.
|
||||
"""
|
||||
def test_bytes(self):
|
||||
"""
|
||||
L{networkString} returns a C{bytes} object passed to it unmodified.
|
||||
"""
|
||||
self.assertEqual(b"foo", networkString(b"foo"))
|
||||
|
||||
|
||||
def test_bytesOutOfRange(self):
|
||||
"""
|
||||
L{networkString} raises C{UnicodeError} if passed a C{bytes} instance
|
||||
containing bytes not used by ASCII.
|
||||
"""
|
||||
self.assertRaises(
|
||||
UnicodeError, networkString, u"\N{SNOWMAN}".encode('utf-8'))
|
||||
if _PY3:
|
||||
test_bytes.skip = test_bytesOutOfRange.skip = (
|
||||
"Bytes behavior of networkString only provided on Python 2.")
|
||||
|
||||
def test_unicode(self):
|
||||
"""
|
||||
L{networkString} returns a C{unicode} object passed to it encoded into a
|
||||
C{bytes} instance.
|
||||
"""
|
||||
self.assertEqual(b"foo", networkString(u"foo"))
|
||||
|
||||
|
||||
def test_unicodeOutOfRange(self):
|
||||
"""
|
||||
L{networkString} raises L{UnicodeError} if passed a C{unicode} instance
|
||||
containing characters not encodable in ASCII.
|
||||
"""
|
||||
self.assertRaises(
|
||||
UnicodeError, networkString, u"\N{SNOWMAN}")
|
||||
if not _PY3:
|
||||
test_unicode.skip = test_unicodeOutOfRange.skip = (
|
||||
"Unicode behavior of networkString only provided on Python 3.")
|
||||
|
||||
|
||||
def test_nonString(self):
|
||||
"""
|
||||
L{networkString} raises L{TypeError} if passed a non-string object or
|
||||
the wrong type of string object.
|
||||
"""
|
||||
self.assertRaises(TypeError, networkString, object())
|
||||
if _PY3:
|
||||
self.assertRaises(TypeError, networkString, b"bytes")
|
||||
else:
|
||||
self.assertRaises(TypeError, networkString, u"text")
|
||||
|
||||
|
||||
|
||||
class ReraiseTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
L{reraise} re-raises exceptions on both Python 2 and Python 3.
|
||||
"""
|
||||
|
||||
def test_reraiseWithNone(self):
|
||||
"""
|
||||
Calling L{reraise} with an exception instance and a traceback of
|
||||
C{None} re-raises it with a new traceback.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
typ, value, tb = sys.exc_info()
|
||||
try:
|
||||
reraise(value, None)
|
||||
except:
|
||||
typ2, value2, tb2 = sys.exc_info()
|
||||
self.assertEqual(typ2, ZeroDivisionError)
|
||||
self.assertTrue(value is value2)
|
||||
self.assertNotEqual(traceback.format_tb(tb)[-1],
|
||||
traceback.format_tb(tb2)[-1])
|
||||
else:
|
||||
self.fail("The exception was not raised.")
|
||||
|
||||
|
||||
def test_reraiseWithTraceback(self):
|
||||
"""
|
||||
Calling L{reraise} with an exception instance and a traceback
|
||||
re-raises the exception with the given traceback.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
typ, value, tb = sys.exc_info()
|
||||
try:
|
||||
reraise(value, tb)
|
||||
except:
|
||||
typ2, value2, tb2 = sys.exc_info()
|
||||
self.assertEqual(typ2, ZeroDivisionError)
|
||||
self.assertTrue(value is value2)
|
||||
self.assertEqual(traceback.format_tb(tb)[-1],
|
||||
traceback.format_tb(tb2)[-1])
|
||||
else:
|
||||
self.fail("The exception was not raised.")
|
||||
|
||||
|
||||
|
||||
class Python3BytesTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{iterbytes}, L{intToBytes}, L{lazyByteSlice}.
|
||||
"""
|
||||
|
||||
def test_iteration(self):
|
||||
"""
|
||||
When L{iterbytes} is called with a bytestring, the returned object
|
||||
can be iterated over, resulting in the individual bytes of the
|
||||
bytestring.
|
||||
"""
|
||||
input = b"abcd"
|
||||
result = list(iterbytes(input))
|
||||
self.assertEqual(result, [b'a', b'b', b'c', b'd'])
|
||||
|
||||
|
||||
def test_intToBytes(self):
|
||||
"""
|
||||
When L{intToBytes} is called with an integer, the result is an
|
||||
ASCII-encoded string representation of the number.
|
||||
"""
|
||||
self.assertEqual(intToBytes(213), b"213")
|
||||
|
||||
|
||||
def test_lazyByteSliceNoOffset(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes returns a semantically equal version
|
||||
of these bytes.
|
||||
"""
|
||||
data = b'123XYZ'
|
||||
self.assertEqual(bytes(lazyByteSlice(data)), data)
|
||||
|
||||
|
||||
def test_lazyByteSliceOffset(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes and an offset returns a semantically
|
||||
equal version of these bytes starting at the given offset.
|
||||
"""
|
||||
data = b'123XYZ'
|
||||
self.assertEqual(bytes(lazyByteSlice(data, 2)), data[2:])
|
||||
|
||||
|
||||
def test_lazyByteSliceOffsetAndLength(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes, an offset and a length returns a
|
||||
semantically equal version of these bytes starting at the given
|
||||
offset, up to the given length.
|
||||
"""
|
||||
data = b'123XYZ'
|
||||
self.assertEqual(bytes(lazyByteSlice(data, 2, 3)), data[2:5])
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.context}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
|
||||
from twisted.python import context
|
||||
|
||||
class ContextTest(SynchronousTestCase):
|
||||
"""
|
||||
Tests for the module-scope APIs for L{twisted.python.context}.
|
||||
"""
|
||||
def test_notPresentIfNotSet(self):
|
||||
"""
|
||||
Arbitrary keys which have not been set in the context have an associated
|
||||
value of C{None}.
|
||||
"""
|
||||
self.assertEqual(context.get("x"), None)
|
||||
|
||||
|
||||
def test_setByCall(self):
|
||||
"""
|
||||
Values may be associated with keys by passing them in a dictionary as
|
||||
the first argument to L{twisted.python.context.call}.
|
||||
"""
|
||||
self.assertEqual(context.call({"x": "y"}, context.get, "x"), "y")
|
||||
|
||||
|
||||
def test_unsetAfterCall(self):
|
||||
"""
|
||||
After a L{twisted.python.context.call} completes, keys specified in the
|
||||
call are no longer associated with the values from that call.
|
||||
"""
|
||||
context.call({"x": "y"}, lambda: None)
|
||||
self.assertEqual(context.get("x"), None)
|
||||
|
||||
|
||||
def test_setDefault(self):
|
||||
"""
|
||||
A default value may be set for a key in the context using
|
||||
L{twisted.python.context.setDefault}.
|
||||
"""
|
||||
key = object()
|
||||
self.addCleanup(context.defaultContextDict.pop, key, None)
|
||||
context.setDefault(key, "y")
|
||||
self.assertEqual("y", context.get(key))
|
||||
|
|
@ -0,0 +1,711 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains tests for L{twisted.internet.task.Cooperator} and
|
||||
related functionality.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.internet import reactor, defer, task
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
|
||||
class FakeDelayedCall(object):
|
||||
"""
|
||||
Fake delayed call which lets us simulate the scheduler.
|
||||
"""
|
||||
def __init__(self, func):
|
||||
"""
|
||||
A function to run, later.
|
||||
"""
|
||||
self.func = func
|
||||
self.cancelled = False
|
||||
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Don't run my function later.
|
||||
"""
|
||||
self.cancelled = True
|
||||
|
||||
|
||||
|
||||
class FakeScheduler(object):
|
||||
"""
|
||||
A fake scheduler for testing against.
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
Create a fake scheduler with a list of work to do.
|
||||
"""
|
||||
self.work = []
|
||||
|
||||
|
||||
def __call__(self, thunk):
|
||||
"""
|
||||
Schedule a unit of work to be done later.
|
||||
"""
|
||||
unit = FakeDelayedCall(thunk)
|
||||
self.work.append(unit)
|
||||
return unit
|
||||
|
||||
|
||||
def pump(self):
|
||||
"""
|
||||
Do all of the work that is currently available to be done.
|
||||
"""
|
||||
work, self.work = self.work, []
|
||||
for unit in work:
|
||||
if not unit.cancelled:
|
||||
unit.func()
|
||||
|
||||
|
||||
|
||||
class TestCooperator(unittest.TestCase):
|
||||
RESULT = 'done'
|
||||
|
||||
def ebIter(self, err):
|
||||
err.trap(task.SchedulerStopped)
|
||||
return self.RESULT
|
||||
|
||||
|
||||
def cbIter(self, ign):
|
||||
self.fail()
|
||||
|
||||
|
||||
def testStoppedRejectsNewTasks(self):
|
||||
"""
|
||||
Test that Cooperators refuse new tasks when they have been stopped.
|
||||
"""
|
||||
def testwith(stuff):
|
||||
c = task.Cooperator()
|
||||
c.stop()
|
||||
d = c.coiterate(iter(()), stuff)
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
return d.addCallback(lambda result:
|
||||
self.assertEqual(result, self.RESULT))
|
||||
return testwith(None).addCallback(lambda ign: testwith(defer.Deferred()))
|
||||
|
||||
|
||||
def testStopRunning(self):
|
||||
"""
|
||||
Test that a running iterator will not run to completion when the
|
||||
cooperator is stopped.
|
||||
"""
|
||||
c = task.Cooperator()
|
||||
def myiter():
|
||||
for myiter.value in range(3):
|
||||
yield myiter.value
|
||||
myiter.value = -1
|
||||
d = c.coiterate(myiter())
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
c.stop()
|
||||
def doasserts(result):
|
||||
self.assertEqual(result, self.RESULT)
|
||||
self.assertEqual(myiter.value, -1)
|
||||
d.addCallback(doasserts)
|
||||
return d
|
||||
|
||||
|
||||
def testStopOutstanding(self):
|
||||
"""
|
||||
An iterator run with L{Cooperator.coiterate} paused on a L{Deferred}
|
||||
yielded by that iterator will fire its own L{Deferred} (the one
|
||||
returned by C{coiterate}) when L{Cooperator.stop} is called.
|
||||
"""
|
||||
testControlD = defer.Deferred()
|
||||
outstandingD = defer.Deferred()
|
||||
def myiter():
|
||||
reactor.callLater(0, testControlD.callback, None)
|
||||
yield outstandingD
|
||||
self.fail()
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
def stopAndGo(ign):
|
||||
c.stop()
|
||||
outstandingD.callback('arglebargle')
|
||||
|
||||
testControlD.addCallback(stopAndGo)
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
|
||||
return d.addCallback(
|
||||
lambda result: self.assertEqual(result, self.RESULT))
|
||||
|
||||
|
||||
def testUnexpectedError(self):
|
||||
c = task.Cooperator()
|
||||
def myiter():
|
||||
if 0:
|
||||
yield None
|
||||
else:
|
||||
raise RuntimeError()
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
|
||||
def testUnexpectedErrorActuallyLater(self):
|
||||
def myiter():
|
||||
D = defer.Deferred()
|
||||
reactor.callLater(0, D.errback, RuntimeError())
|
||||
yield D
|
||||
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
|
||||
def testUnexpectedErrorNotActuallyLater(self):
|
||||
def myiter():
|
||||
yield defer.fail(RuntimeError())
|
||||
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
|
||||
def testCooperation(self):
|
||||
L = []
|
||||
def myiter(things):
|
||||
for th in things:
|
||||
L.append(th)
|
||||
yield None
|
||||
|
||||
groupsOfThings = ['abc', (1, 2, 3), 'def', (4, 5, 6)]
|
||||
|
||||
c = task.Cooperator()
|
||||
tasks = []
|
||||
for stuff in groupsOfThings:
|
||||
tasks.append(c.coiterate(myiter(stuff)))
|
||||
|
||||
return defer.DeferredList(tasks).addCallback(
|
||||
lambda ign: self.assertEqual(tuple(L), sum(zip(*groupsOfThings), ())))
|
||||
|
||||
|
||||
def testResourceExhaustion(self):
|
||||
output = []
|
||||
def myiter():
|
||||
for i in range(100):
|
||||
output.append(i)
|
||||
if i == 9:
|
||||
_TPF.stopped = True
|
||||
yield i
|
||||
|
||||
class _TPF:
|
||||
stopped = False
|
||||
def __call__(self):
|
||||
return self.stopped
|
||||
|
||||
c = task.Cooperator(terminationPredicateFactory=_TPF)
|
||||
c.coiterate(myiter()).addErrback(self.ebIter)
|
||||
c._delayedCall.cancel()
|
||||
# testing a private method because only the test case will ever care
|
||||
# about this, so we have to carefully clean up after ourselves.
|
||||
c._tick()
|
||||
c.stop()
|
||||
self.failUnless(_TPF.stopped)
|
||||
self.assertEqual(output, list(range(10)))
|
||||
|
||||
|
||||
def testCallbackReCoiterate(self):
|
||||
"""
|
||||
If a callback to a deferred returned by coiterate calls coiterate on
|
||||
the same Cooperator, we should make sure to only do the minimal amount
|
||||
of scheduling work. (This test was added to demonstrate a specific bug
|
||||
that was found while writing the scheduler.)
|
||||
"""
|
||||
calls = []
|
||||
|
||||
class FakeCall:
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __repr__(self):
|
||||
return '<FakeCall %r>' % (self.func,)
|
||||
|
||||
def sched(f):
|
||||
self.failIf(calls, repr(calls))
|
||||
calls.append(FakeCall(f))
|
||||
return calls[-1]
|
||||
|
||||
c = task.Cooperator(scheduler=sched, terminationPredicateFactory=lambda: lambda: True)
|
||||
d = c.coiterate(iter(()))
|
||||
|
||||
done = []
|
||||
def anotherTask(ign):
|
||||
c.coiterate(iter(())).addBoth(done.append)
|
||||
|
||||
d.addCallback(anotherTask)
|
||||
|
||||
work = 0
|
||||
while not done:
|
||||
work += 1
|
||||
while calls:
|
||||
calls.pop(0).func()
|
||||
work += 1
|
||||
if work > 50:
|
||||
self.fail("Cooperator took too long")
|
||||
|
||||
|
||||
def test_removingLastTaskStopsScheduledCall(self):
|
||||
"""
|
||||
If the last task in a Cooperator is removed, the scheduled call for
|
||||
the next tick is cancelled, since it is no longer necessary.
|
||||
|
||||
This behavior is useful for tests that want to assert they have left
|
||||
no reactor state behind when they're done.
|
||||
"""
|
||||
calls = [None]
|
||||
def sched(f):
|
||||
calls[0] = FakeDelayedCall(f)
|
||||
return calls[0]
|
||||
coop = task.Cooperator(scheduler=sched)
|
||||
|
||||
# Add two task; this should schedule the tick:
|
||||
task1 = coop.cooperate(iter([1, 2]))
|
||||
task2 = coop.cooperate(iter([1, 2]))
|
||||
self.assertEqual(calls[0].func, coop._tick)
|
||||
|
||||
# Remove first task; scheduled call should still be going:
|
||||
task1.stop()
|
||||
self.assertEqual(calls[0].cancelled, False)
|
||||
self.assertEqual(coop._delayedCall, calls[0])
|
||||
|
||||
# Remove second task; scheduled call should be cancelled:
|
||||
task2.stop()
|
||||
self.assertEqual(calls[0].cancelled, True)
|
||||
self.assertEqual(coop._delayedCall, None)
|
||||
|
||||
# Add another task; scheduled call will be recreated:
|
||||
coop.cooperate(iter([1, 2]))
|
||||
self.assertEqual(calls[0].cancelled, False)
|
||||
self.assertEqual(coop._delayedCall, calls[0])
|
||||
|
||||
|
||||
def test_runningWhenStarted(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{True} if the L{Cooperator}
|
||||
was started on creation.
|
||||
"""
|
||||
c = task.Cooperator()
|
||||
self.assertTrue(c.running)
|
||||
|
||||
|
||||
def test_runningWhenNotStarted(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{False} if the L{Cooperator}
|
||||
has not been started.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
self.assertFalse(c.running)
|
||||
|
||||
|
||||
def test_runningWhenRunning(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{True} when the L{Cooperator}
|
||||
is running.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
c.start()
|
||||
self.addCleanup(c.stop)
|
||||
self.assertTrue(c.running)
|
||||
|
||||
|
||||
def test_runningWhenStopped(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{False} after the L{Cooperator}
|
||||
has been stopped.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
c.start()
|
||||
c.stop()
|
||||
self.assertFalse(c.running)
|
||||
|
||||
|
||||
|
||||
class UnhandledException(Exception):
|
||||
"""
|
||||
An exception that should go unhandled.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class AliasTests(unittest.TestCase):
|
||||
"""
|
||||
Integration test to verify that the global singleton aliases do what
|
||||
they're supposed to.
|
||||
"""
|
||||
|
||||
def test_cooperate(self):
|
||||
"""
|
||||
L{twisted.internet.task.cooperate} ought to run the generator that it is
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
def doit():
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
d.callback("yay")
|
||||
it = doit()
|
||||
theTask = task.cooperate(it)
|
||||
self.assertIn(theTask, task._theCooperator._tasks)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class RunStateTests(unittest.TestCase):
|
||||
"""
|
||||
Tests to verify the behavior of L{CooperativeTask.pause},
|
||||
L{CooperativeTask.resume}, L{CooperativeTask.stop}, exhausting the
|
||||
underlying iterator, and their interactions with each other.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a cooperator with a fake scheduler and a termination predicate
|
||||
that ensures only one unit of work will take place per tick.
|
||||
"""
|
||||
self._doDeferNext = False
|
||||
self._doStopNext = False
|
||||
self._doDieNext = False
|
||||
self.work = []
|
||||
self.scheduler = FakeScheduler()
|
||||
self.cooperator = task.Cooperator(
|
||||
scheduler=self.scheduler,
|
||||
# Always stop after one iteration of work (return a function which
|
||||
# returns a function which always returns True)
|
||||
terminationPredicateFactory=lambda: lambda: True)
|
||||
self.task = self.cooperator.cooperate(self.worker())
|
||||
self.cooperator.start()
|
||||
|
||||
|
||||
def worker(self):
|
||||
"""
|
||||
This is a sample generator which yields Deferreds when we are testing
|
||||
deferral and an ascending integer count otherwise.
|
||||
"""
|
||||
i = 0
|
||||
while True:
|
||||
i += 1
|
||||
if self._doDeferNext:
|
||||
self._doDeferNext = False
|
||||
d = defer.Deferred()
|
||||
self.work.append(d)
|
||||
yield d
|
||||
elif self._doStopNext:
|
||||
return
|
||||
elif self._doDieNext:
|
||||
raise UnhandledException()
|
||||
else:
|
||||
self.work.append(i)
|
||||
yield i
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Drop references to interesting parts of the fixture to allow Deferred
|
||||
errors to be noticed when things start failing.
|
||||
"""
|
||||
del self.task
|
||||
del self.scheduler
|
||||
|
||||
|
||||
def deferNext(self):
|
||||
"""
|
||||
Defer the next result from my worker iterator.
|
||||
"""
|
||||
self._doDeferNext = True
|
||||
|
||||
|
||||
def stopNext(self):
|
||||
"""
|
||||
Make the next result from my worker iterator be completion (raising
|
||||
StopIteration).
|
||||
"""
|
||||
self._doStopNext = True
|
||||
|
||||
|
||||
def dieNext(self):
|
||||
"""
|
||||
Make the next result from my worker iterator be raising an
|
||||
L{UnhandledException}.
|
||||
"""
|
||||
def ignoreUnhandled(failure):
|
||||
failure.trap(UnhandledException)
|
||||
return None
|
||||
self._doDieNext = True
|
||||
|
||||
|
||||
def test_pauseResume(self):
|
||||
"""
|
||||
Cooperators should stop running their tasks when they're paused, and
|
||||
start again when they're resumed.
|
||||
"""
|
||||
# first, sanity check
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1])
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
|
||||
# OK, now for real
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
self.task.resume()
|
||||
# Resuming itself shoult not do any work
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
self.scheduler.pump()
|
||||
# But when the scheduler rolls around again...
|
||||
self.assertEqual(self.work, [1, 2, 3])
|
||||
|
||||
|
||||
def test_resumeNotPaused(self):
|
||||
"""
|
||||
L{CooperativeTask.resume} should raise a L{TaskNotPaused} exception if
|
||||
it was not paused; e.g. if L{CooperativeTask.pause} was not invoked
|
||||
more times than L{CooperativeTask.resume} on that object.
|
||||
"""
|
||||
self.assertRaises(task.NotPaused, self.task.resume)
|
||||
self.task.pause()
|
||||
self.task.resume()
|
||||
self.assertRaises(task.NotPaused, self.task.resume)
|
||||
|
||||
|
||||
def test_pauseTwice(self):
|
||||
"""
|
||||
Pauses on tasks should behave like a stack. If a task is paused twice,
|
||||
it needs to be resumed twice.
|
||||
"""
|
||||
# pause once
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# pause twice
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# resume once (it shouldn't)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# resume twice (now it should go)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1])
|
||||
|
||||
|
||||
def test_pauseWhileDeferred(self):
|
||||
"""
|
||||
C{pause()}ing a task while it is waiting on an outstanding
|
||||
L{defer.Deferred} should put the task into a state where the
|
||||
outstanding L{defer.Deferred} must be called back I{and} the task is
|
||||
C{resume}d before it will continue processing.
|
||||
"""
|
||||
self.deferNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.failUnless(isinstance(self.work[0], defer.Deferred))
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.work[0].callback("STUFF!")
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 2)
|
||||
self.assertEqual(self.work[1], 2)
|
||||
|
||||
|
||||
def test_whenDone(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a Deferred which fires when the
|
||||
Cooperator's iterator is exhausted. It returns a new Deferred each
|
||||
time it is called; callbacks added to other invocations will not modify
|
||||
the value that subsequent invocations will fire with.
|
||||
"""
|
||||
|
||||
deferred1 = self.task.whenDone()
|
||||
deferred2 = self.task.whenDone()
|
||||
results1 = []
|
||||
results2 = []
|
||||
final1 = []
|
||||
final2 = []
|
||||
|
||||
def callbackOne(result):
|
||||
results1.append(result)
|
||||
return 1
|
||||
|
||||
def callbackTwo(result):
|
||||
results2.append(result)
|
||||
return 2
|
||||
|
||||
deferred1.addCallback(callbackOne)
|
||||
deferred2.addCallback(callbackTwo)
|
||||
|
||||
deferred1.addCallback(final1.append)
|
||||
deferred2.addCallback(final2.append)
|
||||
|
||||
# exhaust the task iterator
|
||||
# callbacks fire
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
|
||||
self.assertEqual(len(results1), 1)
|
||||
self.assertEqual(len(results2), 1)
|
||||
|
||||
self.assertIdentical(results1[0], self.task._iterator)
|
||||
self.assertIdentical(results2[0], self.task._iterator)
|
||||
|
||||
self.assertEqual(final1, [1])
|
||||
self.assertEqual(final2, [2])
|
||||
|
||||
|
||||
def test_whenDoneError(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that will fail
|
||||
when the iterable's C{next} method raises an exception, with that
|
||||
exception.
|
||||
"""
|
||||
deferred1 = self.task.whenDone()
|
||||
results = []
|
||||
deferred1.addErrback(results.append)
|
||||
self.dieNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].check(UnhandledException), UnhandledException)
|
||||
|
||||
|
||||
def test_whenDoneStop(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that fails with
|
||||
L{TaskStopped} when the C{stop} method is called on that
|
||||
L{CooperativeTask}.
|
||||
"""
|
||||
deferred1 = self.task.whenDone()
|
||||
errors = []
|
||||
deferred1.addErrback(errors.append)
|
||||
self.task.stop()
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].check(task.TaskStopped), task.TaskStopped)
|
||||
|
||||
|
||||
def test_whenDoneAlreadyDone(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} will return a L{defer.Deferred} that will
|
||||
succeed immediately if its iterator has already completed.
|
||||
"""
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
results = []
|
||||
self.task.whenDone().addCallback(results.append)
|
||||
self.assertEqual(results, [self.task._iterator])
|
||||
|
||||
|
||||
def test_stopStops(self):
|
||||
"""
|
||||
C{stop()}ping a task should cause it to be removed from the run just as
|
||||
C{pause()}ing, with the distinction that C{resume()} will raise a
|
||||
L{TaskStopped} exception.
|
||||
"""
|
||||
self.task.stop()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 0)
|
||||
self.assertRaises(task.TaskStopped, self.task.stop)
|
||||
self.assertRaises(task.TaskStopped, self.task.pause)
|
||||
# Sanity check - it's still not scheduled, is it?
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
|
||||
def test_pauseStopResume(self):
|
||||
"""
|
||||
C{resume()}ing a paused, stopped task should be a no-op; it should not
|
||||
raise an exception, because it's paused, but neither should it actually
|
||||
do more work from the task.
|
||||
"""
|
||||
self.task.pause()
|
||||
self.task.stop()
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
|
||||
def test_stopDeferred(self):
|
||||
"""
|
||||
As a corrolary of the interaction of C{pause()} and C{unpause()},
|
||||
C{stop()}ping a task which is waiting on a L{Deferred} should cause the
|
||||
task to gracefully shut down, meaning that it should not be unpaused
|
||||
when the deferred fires.
|
||||
"""
|
||||
self.deferNext()
|
||||
self.scheduler.pump()
|
||||
d = self.work.pop()
|
||||
self.assertEqual(self.task._pauseCount, 1)
|
||||
results = []
|
||||
d.addBoth(results.append)
|
||||
self.scheduler.pump()
|
||||
self.task.stop()
|
||||
self.scheduler.pump()
|
||||
d.callback(7)
|
||||
self.scheduler.pump()
|
||||
# Let's make sure that Deferred doesn't come out fried with an
|
||||
# unhandled error that will be logged. The value is None, rather than
|
||||
# our test value, 7, because this Deferred is returned to and consumed
|
||||
# by the cooperator code. Its callback therefore has no contract.
|
||||
self.assertEqual(results, [None])
|
||||
# But more importantly, no further work should have happened.
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
|
||||
def test_stopExhausted(self):
|
||||
"""
|
||||
C{stop()}ping a L{CooperativeTask} whose iterator has been exhausted
|
||||
should raise L{TaskDone}.
|
||||
"""
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
self.assertRaises(task.TaskDone, self.task.stop)
|
||||
|
||||
|
||||
def test_stopErrored(self):
|
||||
"""
|
||||
C{stop()}ping a L{CooperativeTask} whose iterator has encountered an
|
||||
error should raise L{TaskFailed}.
|
||||
"""
|
||||
self.dieNext()
|
||||
self.scheduler.pump()
|
||||
self.assertRaises(task.TaskFailed, self.task.stop)
|
||||
|
||||
|
||||
def test_stopCooperatorReentrancy(self):
|
||||
"""
|
||||
If a callback of a L{Deferred} from L{CooperativeTask.whenDone} calls
|
||||
C{Cooperator.stop} on its L{CooperativeTask._cooperator}, the
|
||||
L{Cooperator} will stop, but the L{CooperativeTask} whose callback is
|
||||
calling C{stop} should already be considered 'stopped' by the time the
|
||||
callback is running, and therefore removed from the
|
||||
L{CoooperativeTask}.
|
||||
"""
|
||||
callbackPhases = []
|
||||
def stopit(result):
|
||||
callbackPhases.append(result)
|
||||
self.cooperator.stop()
|
||||
# "done" here is a sanity check to make sure that we get all the
|
||||
# way through the callback; i.e. stop() shouldn't be raising an
|
||||
# exception due to the stopped-ness of our main task.
|
||||
callbackPhases.append("done")
|
||||
self.task.whenDone().addCallback(stopit)
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(callbackPhases, [self.task._iterator, "done"])
|
||||
|
||||
|
||||
|
||||
2361
Linux_i686/lib/python2.7/site-packages/twisted/test/test_defer.py
Normal file
2361
Linux_i686/lib/python2.7/site-packages/twisted/test/test_defer.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,301 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.defer.deferredGenerator} and related APIs.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.internet.defer import waitForDeferred, deferredGenerator, Deferred
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
def getThing():
|
||||
d = Deferred()
|
||||
reactor.callLater(0, d.callback, "hi")
|
||||
return d
|
||||
|
||||
def getOwie():
|
||||
d = Deferred()
|
||||
def CRAP():
|
||||
d.errback(ZeroDivisionError('OMG'))
|
||||
reactor.callLater(0, CRAP)
|
||||
return d
|
||||
|
||||
# NOTE: most of the tests in DeferredGeneratorTests are duplicated
|
||||
# with slightly different syntax for the InlineCallbacksTests below.
|
||||
|
||||
class TerminalException(Exception):
|
||||
pass
|
||||
|
||||
class BaseDefgenTests:
|
||||
"""
|
||||
This class sets up a bunch of test cases which will test both
|
||||
deferredGenerator and inlineCallbacks based generators. The subclasses
|
||||
DeferredGeneratorTests and InlineCallbacksTests each provide the actual
|
||||
generator implementations tested.
|
||||
"""
|
||||
|
||||
def testBasics(self):
|
||||
"""
|
||||
Test that a normal deferredGenerator works. Tests yielding a
|
||||
deferred which callbacks, as well as a deferred errbacks. Also
|
||||
ensures returning a final value works.
|
||||
"""
|
||||
|
||||
return self._genBasics().addCallback(self.assertEqual, 'WOOSH')
|
||||
|
||||
def testBuggy(self):
|
||||
"""
|
||||
Ensure that a buggy generator properly signals a Failure
|
||||
condition on result deferred.
|
||||
"""
|
||||
return self.assertFailure(self._genBuggy(), ZeroDivisionError)
|
||||
|
||||
def testNothing(self):
|
||||
"""Test that a generator which never yields results in None."""
|
||||
|
||||
return self._genNothing().addCallback(self.assertEqual, None)
|
||||
|
||||
def testHandledTerminalFailure(self):
|
||||
"""
|
||||
Create a Deferred Generator which yields a Deferred which fails and
|
||||
handles the exception which results. Assert that the Deferred
|
||||
Generator does not errback its Deferred.
|
||||
"""
|
||||
return self._genHandledTerminalFailure().addCallback(self.assertEqual, None)
|
||||
|
||||
def testHandledTerminalAsyncFailure(self):
|
||||
"""
|
||||
Just like testHandledTerminalFailure, only with a Deferred which fires
|
||||
asynchronously with an error.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
deferredGeneratorResultDeferred = self._genHandledTerminalAsyncFailure(d)
|
||||
d.errback(TerminalException("Handled Terminal Failure"))
|
||||
return deferredGeneratorResultDeferred.addCallback(
|
||||
self.assertEqual, None)
|
||||
|
||||
def testStackUsage(self):
|
||||
"""
|
||||
Make sure we don't blow the stack when yielding immediately
|
||||
available deferreds.
|
||||
"""
|
||||
return self._genStackUsage().addCallback(self.assertEqual, 0)
|
||||
|
||||
def testStackUsage2(self):
|
||||
"""
|
||||
Make sure we don't blow the stack when yielding immediately
|
||||
available values.
|
||||
"""
|
||||
return self._genStackUsage2().addCallback(self.assertEqual, 0)
|
||||
|
||||
|
||||
|
||||
|
||||
class DeferredGeneratorTests(BaseDefgenTests, unittest.TestCase):
|
||||
|
||||
# First provide all the generator impls necessary for BaseDefgenTests
|
||||
def _genBasics(self):
|
||||
|
||||
x = waitForDeferred(getThing())
|
||||
yield x
|
||||
x = x.getResult()
|
||||
|
||||
self.assertEqual(x, "hi")
|
||||
|
||||
ow = waitForDeferred(getOwie())
|
||||
yield ow
|
||||
try:
|
||||
ow.getResult()
|
||||
except ZeroDivisionError as e:
|
||||
self.assertEqual(str(e), 'OMG')
|
||||
yield "WOOSH"
|
||||
return
|
||||
_genBasics = deferredGenerator(_genBasics)
|
||||
|
||||
def _genBuggy(self):
|
||||
yield waitForDeferred(getThing())
|
||||
1//0
|
||||
_genBuggy = deferredGenerator(_genBuggy)
|
||||
|
||||
|
||||
def _genNothing(self):
|
||||
if 0: yield 1
|
||||
_genNothing = deferredGenerator(_genNothing)
|
||||
|
||||
def _genHandledTerminalFailure(self):
|
||||
x = waitForDeferred(defer.fail(TerminalException("Handled Terminal Failure")))
|
||||
yield x
|
||||
try:
|
||||
x.getResult()
|
||||
except TerminalException:
|
||||
pass
|
||||
_genHandledTerminalFailure = deferredGenerator(_genHandledTerminalFailure)
|
||||
|
||||
|
||||
def _genHandledTerminalAsyncFailure(self, d):
|
||||
x = waitForDeferred(d)
|
||||
yield x
|
||||
try:
|
||||
x.getResult()
|
||||
except TerminalException:
|
||||
pass
|
||||
_genHandledTerminalAsyncFailure = deferredGenerator(_genHandledTerminalAsyncFailure)
|
||||
|
||||
|
||||
def _genStackUsage(self):
|
||||
for x in range(5000):
|
||||
# Test with yielding a deferred
|
||||
x = waitForDeferred(defer.succeed(1))
|
||||
yield x
|
||||
x = x.getResult()
|
||||
yield 0
|
||||
_genStackUsage = deferredGenerator(_genStackUsage)
|
||||
|
||||
def _genStackUsage2(self):
|
||||
for x in range(5000):
|
||||
# Test with yielding a random value
|
||||
yield 1
|
||||
yield 0
|
||||
_genStackUsage2 = deferredGenerator(_genStackUsage2)
|
||||
|
||||
# Tests unique to deferredGenerator
|
||||
|
||||
def testDeferredYielding(self):
|
||||
"""
|
||||
Ensure that yielding a Deferred directly is trapped as an
|
||||
error.
|
||||
"""
|
||||
# See the comment _deferGenerator about d.callback(Deferred).
|
||||
def _genDeferred():
|
||||
yield getThing()
|
||||
_genDeferred = deferredGenerator(_genDeferred)
|
||||
|
||||
return self.assertFailure(_genDeferred(), TypeError)
|
||||
|
||||
|
||||
|
||||
class InlineCallbacksTests(BaseDefgenTests, unittest.TestCase):
|
||||
# First provide all the generator impls necessary for BaseDefgenTests
|
||||
|
||||
def _genBasics(self):
|
||||
|
||||
x = yield getThing()
|
||||
|
||||
self.assertEqual(x, "hi")
|
||||
|
||||
try:
|
||||
ow = yield getOwie()
|
||||
except ZeroDivisionError as e:
|
||||
self.assertEqual(str(e), 'OMG')
|
||||
returnValue("WOOSH")
|
||||
_genBasics = inlineCallbacks(_genBasics)
|
||||
|
||||
def _genBuggy(self):
|
||||
yield getThing()
|
||||
1/0
|
||||
_genBuggy = inlineCallbacks(_genBuggy)
|
||||
|
||||
|
||||
def _genNothing(self):
|
||||
if 0: yield 1
|
||||
_genNothing = inlineCallbacks(_genNothing)
|
||||
|
||||
|
||||
def _genHandledTerminalFailure(self):
|
||||
try:
|
||||
x = yield defer.fail(TerminalException("Handled Terminal Failure"))
|
||||
except TerminalException:
|
||||
pass
|
||||
_genHandledTerminalFailure = inlineCallbacks(_genHandledTerminalFailure)
|
||||
|
||||
|
||||
def _genHandledTerminalAsyncFailure(self, d):
|
||||
try:
|
||||
x = yield d
|
||||
except TerminalException:
|
||||
pass
|
||||
_genHandledTerminalAsyncFailure = inlineCallbacks(
|
||||
_genHandledTerminalAsyncFailure)
|
||||
|
||||
|
||||
def _genStackUsage(self):
|
||||
for x in range(5000):
|
||||
# Test with yielding a deferred
|
||||
x = yield defer.succeed(1)
|
||||
returnValue(0)
|
||||
_genStackUsage = inlineCallbacks(_genStackUsage)
|
||||
|
||||
def _genStackUsage2(self):
|
||||
for x in range(5000):
|
||||
# Test with yielding a random value
|
||||
yield 1
|
||||
returnValue(0)
|
||||
_genStackUsage2 = inlineCallbacks(_genStackUsage2)
|
||||
|
||||
# Tests unique to inlineCallbacks
|
||||
|
||||
def testYieldNonDeferrred(self):
|
||||
"""
|
||||
Ensure that yielding a non-deferred passes it back as the
|
||||
result of the yield expression.
|
||||
"""
|
||||
def _test():
|
||||
x = yield 5
|
||||
returnValue(5)
|
||||
_test = inlineCallbacks(_test)
|
||||
|
||||
return _test().addCallback(self.assertEqual, 5)
|
||||
|
||||
def testReturnNoValue(self):
|
||||
"""Ensure a standard python return results in a None result."""
|
||||
def _noReturn():
|
||||
yield 5
|
||||
return
|
||||
_noReturn = inlineCallbacks(_noReturn)
|
||||
|
||||
return _noReturn().addCallback(self.assertEqual, None)
|
||||
|
||||
def testReturnValue(self):
|
||||
"""Ensure that returnValue works."""
|
||||
def _return():
|
||||
yield 5
|
||||
returnValue(6)
|
||||
_return = inlineCallbacks(_return)
|
||||
|
||||
return _return().addCallback(self.assertEqual, 6)
|
||||
|
||||
|
||||
def test_nonGeneratorReturn(self):
|
||||
"""
|
||||
Ensure that C{TypeError} with a message about L{inlineCallbacks} is
|
||||
raised when a non-generator returns something other than a generator.
|
||||
"""
|
||||
def _noYield():
|
||||
return 5
|
||||
_noYield = inlineCallbacks(_noYield)
|
||||
|
||||
self.assertIn("inlineCallbacks",
|
||||
str(self.assertRaises(TypeError, _noYield)))
|
||||
|
||||
|
||||
def test_nonGeneratorReturnValue(self):
|
||||
"""
|
||||
Ensure that C{TypeError} with a message about L{inlineCallbacks} is
|
||||
raised when a non-generator calls L{returnValue}.
|
||||
"""
|
||||
def _noYield():
|
||||
returnValue(5)
|
||||
_noYield = inlineCallbacks(_noYield)
|
||||
|
||||
self.assertIn("inlineCallbacks",
|
||||
str(self.assertRaises(TypeError, _noYield)))
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import dict
|
||||
|
||||
paramString = "\"This is a dqstring \\w\\i\\t\\h boring stuff like: \\\"\" and t\\hes\\\"e are a\\to\\ms"
|
||||
goodparams = ["This is a dqstring with boring stuff like: \"", "and", "thes\"e", "are", "atoms"]
|
||||
|
||||
class ParamTest(unittest.TestCase):
|
||||
def testParseParam(self):
|
||||
"""Testing command response handling"""
|
||||
params = []
|
||||
rest = paramString
|
||||
while 1:
|
||||
(param, rest) = dict.parseParam(rest)
|
||||
if param == None:
|
||||
break
|
||||
params.append(param)
|
||||
self.assertEqual(params, goodparams)#, "DictClient.parseParam returns unexpected results")
|
||||
|
|
@ -0,0 +1,681 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.cred._digest} and the associated bits in
|
||||
L{twisted.cred.credentials}.
|
||||
"""
|
||||
from hashlib import md5, sha1
|
||||
|
||||
from zope.interface.verify import verifyObject
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.cred.error import LoginFailed
|
||||
from twisted.cred.credentials import calcHA1, calcHA2, IUsernameDigestHash
|
||||
from twisted.cred.credentials import calcResponse, DigestCredentialFactory
|
||||
|
||||
def b64encode(s):
|
||||
return s.encode('base64').strip()
|
||||
|
||||
|
||||
class FakeDigestCredentialFactory(DigestCredentialFactory):
|
||||
"""
|
||||
A Fake Digest Credential Factory that generates a predictable
|
||||
nonce and opaque
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FakeDigestCredentialFactory, self).__init__(*args, **kwargs)
|
||||
self.privateKey = "0"
|
||||
|
||||
|
||||
def _generateNonce(self):
|
||||
"""
|
||||
Generate a static nonce
|
||||
"""
|
||||
return '178288758716122392881254770685'
|
||||
|
||||
|
||||
def _getTime(self):
|
||||
"""
|
||||
Return a stable time
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
class DigestAuthTests(TestCase):
|
||||
"""
|
||||
L{TestCase} mixin class which defines a number of tests for
|
||||
L{DigestCredentialFactory}. Because this mixin defines C{setUp}, it
|
||||
must be inherited before L{TestCase}.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a DigestCredentialFactory for testing
|
||||
"""
|
||||
self.username = "foobar"
|
||||
self.password = "bazquux"
|
||||
self.realm = "test realm"
|
||||
self.algorithm = "md5"
|
||||
self.cnonce = "29fc54aa1641c6fa0e151419361c8f23"
|
||||
self.qop = "auth"
|
||||
self.uri = "/write/"
|
||||
self.clientAddress = IPv4Address('TCP', '10.2.3.4', 43125)
|
||||
self.method = 'GET'
|
||||
self.credentialFactory = DigestCredentialFactory(
|
||||
self.algorithm, self.realm)
|
||||
|
||||
|
||||
def test_MD5HashA1(self, _algorithm='md5', _hash=md5):
|
||||
"""
|
||||
L{calcHA1} accepts the C{'md5'} algorithm and returns an MD5 hash of
|
||||
its parameters, excluding the nonce and cnonce.
|
||||
"""
|
||||
nonce = 'abc123xyz'
|
||||
hashA1 = calcHA1(_algorithm, self.username, self.realm, self.password,
|
||||
nonce, self.cnonce)
|
||||
a1 = '%s:%s:%s' % (self.username, self.realm, self.password)
|
||||
expected = _hash(a1).hexdigest()
|
||||
self.assertEqual(hashA1, expected)
|
||||
|
||||
|
||||
def test_MD5SessionHashA1(self):
|
||||
"""
|
||||
L{calcHA1} accepts the C{'md5-sess'} algorithm and returns an MD5 hash
|
||||
of its parameters, including the nonce and cnonce.
|
||||
"""
|
||||
nonce = 'xyz321abc'
|
||||
hashA1 = calcHA1('md5-sess', self.username, self.realm, self.password,
|
||||
nonce, self.cnonce)
|
||||
a1 = '%s:%s:%s' % (self.username, self.realm, self.password)
|
||||
ha1 = md5(a1).digest()
|
||||
a1 = '%s:%s:%s' % (ha1, nonce, self.cnonce)
|
||||
expected = md5(a1).hexdigest()
|
||||
self.assertEqual(hashA1, expected)
|
||||
|
||||
|
||||
def test_SHAHashA1(self):
|
||||
"""
|
||||
L{calcHA1} accepts the C{'sha'} algorithm and returns a SHA hash of its
|
||||
parameters, excluding the nonce and cnonce.
|
||||
"""
|
||||
self.test_MD5HashA1('sha', sha1)
|
||||
|
||||
|
||||
def test_MD5HashA2Auth(self, _algorithm='md5', _hash=md5):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'md5'} algorithm and returns an MD5 hash of
|
||||
its arguments, excluding the entity hash for QOP other than
|
||||
C{'auth-int'}.
|
||||
"""
|
||||
method = 'GET'
|
||||
hashA2 = calcHA2(_algorithm, method, self.uri, 'auth', None)
|
||||
a2 = '%s:%s' % (method, self.uri)
|
||||
expected = _hash(a2).hexdigest()
|
||||
self.assertEqual(hashA2, expected)
|
||||
|
||||
|
||||
def test_MD5HashA2AuthInt(self, _algorithm='md5', _hash=md5):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'md5'} algorithm and returns an MD5 hash of
|
||||
its arguments, including the entity hash for QOP of C{'auth-int'}.
|
||||
"""
|
||||
method = 'GET'
|
||||
hentity = 'foobarbaz'
|
||||
hashA2 = calcHA2(_algorithm, method, self.uri, 'auth-int', hentity)
|
||||
a2 = '%s:%s:%s' % (method, self.uri, hentity)
|
||||
expected = _hash(a2).hexdigest()
|
||||
self.assertEqual(hashA2, expected)
|
||||
|
||||
|
||||
def test_MD5SessHashA2Auth(self):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'md5-sess'} algorithm and QOP of C{'auth'} and
|
||||
returns the same value as it does for the C{'md5'} algorithm.
|
||||
"""
|
||||
self.test_MD5HashA2Auth('md5-sess')
|
||||
|
||||
|
||||
def test_MD5SessHashA2AuthInt(self):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'md5-sess'} algorithm and QOP of C{'auth-int'}
|
||||
and returns the same value as it does for the C{'md5'} algorithm.
|
||||
"""
|
||||
self.test_MD5HashA2AuthInt('md5-sess')
|
||||
|
||||
|
||||
def test_SHAHashA2Auth(self):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'sha'} algorithm and returns a SHA hash of
|
||||
its arguments, excluding the entity hash for QOP other than
|
||||
C{'auth-int'}.
|
||||
"""
|
||||
self.test_MD5HashA2Auth('sha', sha1)
|
||||
|
||||
|
||||
def test_SHAHashA2AuthInt(self):
|
||||
"""
|
||||
L{calcHA2} accepts the C{'sha'} algorithm and returns a SHA hash of
|
||||
its arguments, including the entity hash for QOP of C{'auth-int'}.
|
||||
"""
|
||||
self.test_MD5HashA2AuthInt('sha', sha1)
|
||||
|
||||
|
||||
def test_MD5HashResponse(self, _algorithm='md5', _hash=md5):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'md5'} algorithm and returns an MD5 hash
|
||||
of its parameters, excluding the nonce count, client nonce, and QoP
|
||||
value if the nonce count and client nonce are C{None}
|
||||
"""
|
||||
hashA1 = 'abc123'
|
||||
hashA2 = '789xyz'
|
||||
nonce = 'lmnopq'
|
||||
|
||||
response = '%s:%s:%s' % (hashA1, nonce, hashA2)
|
||||
expected = _hash(response).hexdigest()
|
||||
|
||||
digest = calcResponse(hashA1, hashA2, _algorithm, nonce, None, None,
|
||||
None)
|
||||
self.assertEqual(expected, digest)
|
||||
|
||||
|
||||
def test_MD5SessionHashResponse(self):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'md5-sess'} algorithm and returns an MD5
|
||||
hash of its parameters, excluding the nonce count, client nonce, and
|
||||
QoP value if the nonce count and client nonce are C{None}
|
||||
"""
|
||||
self.test_MD5HashResponse('md5-sess')
|
||||
|
||||
|
||||
def test_SHAHashResponse(self):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'sha'} algorithm and returns a SHA hash
|
||||
of its parameters, excluding the nonce count, client nonce, and QoP
|
||||
value if the nonce count and client nonce are C{None}
|
||||
"""
|
||||
self.test_MD5HashResponse('sha', sha1)
|
||||
|
||||
|
||||
def test_MD5HashResponseExtra(self, _algorithm='md5', _hash=md5):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'md5'} algorithm and returns an MD5 hash
|
||||
of its parameters, including the nonce count, client nonce, and QoP
|
||||
value if they are specified.
|
||||
"""
|
||||
hashA1 = 'abc123'
|
||||
hashA2 = '789xyz'
|
||||
nonce = 'lmnopq'
|
||||
nonceCount = '00000004'
|
||||
clientNonce = 'abcxyz123'
|
||||
qop = 'auth'
|
||||
|
||||
response = '%s:%s:%s:%s:%s:%s' % (
|
||||
hashA1, nonce, nonceCount, clientNonce, qop, hashA2)
|
||||
expected = _hash(response).hexdigest()
|
||||
|
||||
digest = calcResponse(
|
||||
hashA1, hashA2, _algorithm, nonce, nonceCount, clientNonce, qop)
|
||||
self.assertEqual(expected, digest)
|
||||
|
||||
|
||||
def test_MD5SessionHashResponseExtra(self):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'md5-sess'} algorithm and returns an MD5
|
||||
hash of its parameters, including the nonce count, client nonce, and
|
||||
QoP value if they are specified.
|
||||
"""
|
||||
self.test_MD5HashResponseExtra('md5-sess')
|
||||
|
||||
|
||||
def test_SHAHashResponseExtra(self):
|
||||
"""
|
||||
L{calcResponse} accepts the C{'sha'} algorithm and returns a SHA hash
|
||||
of its parameters, including the nonce count, client nonce, and QoP
|
||||
value if they are specified.
|
||||
"""
|
||||
self.test_MD5HashResponseExtra('sha', sha1)
|
||||
|
||||
|
||||
def formatResponse(self, quotes=True, **kw):
|
||||
"""
|
||||
Format all given keyword arguments and their values suitably for use as
|
||||
the value of an HTTP header.
|
||||
|
||||
@types quotes: C{bool}
|
||||
@param quotes: A flag indicating whether to quote the values of each
|
||||
field in the response.
|
||||
|
||||
@param **kw: Keywords and C{str} values which will be treated as field
|
||||
name/value pairs to include in the result.
|
||||
|
||||
@rtype: C{str}
|
||||
@return: The given fields formatted for use as an HTTP header value.
|
||||
"""
|
||||
if 'username' not in kw:
|
||||
kw['username'] = self.username
|
||||
if 'realm' not in kw:
|
||||
kw['realm'] = self.realm
|
||||
if 'algorithm' not in kw:
|
||||
kw['algorithm'] = self.algorithm
|
||||
if 'qop' not in kw:
|
||||
kw['qop'] = self.qop
|
||||
if 'cnonce' not in kw:
|
||||
kw['cnonce'] = self.cnonce
|
||||
if 'uri' not in kw:
|
||||
kw['uri'] = self.uri
|
||||
if quotes:
|
||||
quote = '"'
|
||||
else:
|
||||
quote = ''
|
||||
return ', '.join([
|
||||
'%s=%s%s%s' % (k, quote, v, quote)
|
||||
for (k, v)
|
||||
in kw.iteritems()
|
||||
if v is not None])
|
||||
|
||||
|
||||
def getDigestResponse(self, challenge, ncount):
|
||||
"""
|
||||
Calculate the response for the given challenge
|
||||
"""
|
||||
nonce = challenge.get('nonce')
|
||||
algo = challenge.get('algorithm').lower()
|
||||
qop = challenge.get('qop')
|
||||
|
||||
ha1 = calcHA1(
|
||||
algo, self.username, self.realm, self.password, nonce, self.cnonce)
|
||||
ha2 = calcHA2(algo, "GET", self.uri, qop, None)
|
||||
expected = calcResponse(ha1, ha2, algo, nonce, ncount, self.cnonce, qop)
|
||||
return expected
|
||||
|
||||
|
||||
def test_response(self, quotes=True):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} accepts a digest challenge response
|
||||
and parses it into an L{IUsernameHashedPassword} provider.
|
||||
"""
|
||||
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
nc = "00000001"
|
||||
clientResponse = self.formatResponse(
|
||||
quotes=quotes,
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
creds = self.credentialFactory.decode(
|
||||
clientResponse, self.method, self.clientAddress.host)
|
||||
self.assertTrue(creds.checkPassword(self.password))
|
||||
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
|
||||
|
||||
|
||||
def test_responseWithoutQuotes(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} accepts a digest challenge response
|
||||
which does not quote the values of its fields and parses it into an
|
||||
L{IUsernameHashedPassword} provider in the same way it would a
|
||||
response which included quoted field values.
|
||||
"""
|
||||
self.test_response(False)
|
||||
|
||||
|
||||
def test_responseWithCommaURI(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} accepts a digest challenge response
|
||||
which quotes the values of its fields and includes a C{b","} in the URI
|
||||
field.
|
||||
"""
|
||||
self.uri = b"/some,path/"
|
||||
self.test_response(True)
|
||||
|
||||
|
||||
def test_caseInsensitiveAlgorithm(self):
|
||||
"""
|
||||
The case of the algorithm value in the response is ignored when
|
||||
checking the credentials.
|
||||
"""
|
||||
self.algorithm = 'MD5'
|
||||
self.test_response()
|
||||
|
||||
|
||||
def test_md5DefaultAlgorithm(self):
|
||||
"""
|
||||
The algorithm defaults to MD5 if it is not supplied in the response.
|
||||
"""
|
||||
self.algorithm = None
|
||||
self.test_response()
|
||||
|
||||
|
||||
def test_responseWithoutClientIP(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} accepts a digest challenge response
|
||||
even if the client address it is passed is C{None}.
|
||||
"""
|
||||
challenge = self.credentialFactory.getChallenge(None)
|
||||
|
||||
nc = "00000001"
|
||||
clientResponse = self.formatResponse(
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
creds = self.credentialFactory.decode(clientResponse, self.method, None)
|
||||
self.assertTrue(creds.checkPassword(self.password))
|
||||
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
|
||||
|
||||
|
||||
def test_multiResponse(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} handles multiple responses to a
|
||||
single challenge.
|
||||
"""
|
||||
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
nc = "00000001"
|
||||
clientResponse = self.formatResponse(
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
|
||||
creds = self.credentialFactory.decode(clientResponse, self.method,
|
||||
self.clientAddress.host)
|
||||
self.assertTrue(creds.checkPassword(self.password))
|
||||
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
|
||||
|
||||
nc = "00000002"
|
||||
clientResponse = self.formatResponse(
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
|
||||
creds = self.credentialFactory.decode(clientResponse, self.method,
|
||||
self.clientAddress.host)
|
||||
self.assertTrue(creds.checkPassword(self.password))
|
||||
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
|
||||
|
||||
|
||||
def test_failsWithDifferentMethod(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} returns an L{IUsernameHashedPassword}
|
||||
provider which rejects a correct password for the given user if the
|
||||
challenge response request is made using a different HTTP method than
|
||||
was used to request the initial challenge.
|
||||
"""
|
||||
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
nc = "00000001"
|
||||
clientResponse = self.formatResponse(
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
creds = self.credentialFactory.decode(clientResponse, 'POST',
|
||||
self.clientAddress.host)
|
||||
self.assertFalse(creds.checkPassword(self.password))
|
||||
self.assertFalse(creds.checkPassword(self.password + 'wrong'))
|
||||
|
||||
|
||||
def test_noUsername(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
|
||||
has no username field or if the username field is empty.
|
||||
"""
|
||||
# Check for no username
|
||||
e = self.assertRaises(
|
||||
LoginFailed,
|
||||
self.credentialFactory.decode,
|
||||
self.formatResponse(username=None),
|
||||
self.method, self.clientAddress.host)
|
||||
self.assertEqual(str(e), "Invalid response, no username given.")
|
||||
|
||||
# Check for an empty username
|
||||
e = self.assertRaises(
|
||||
LoginFailed,
|
||||
self.credentialFactory.decode,
|
||||
self.formatResponse(username=""),
|
||||
self.method, self.clientAddress.host)
|
||||
self.assertEqual(str(e), "Invalid response, no username given.")
|
||||
|
||||
|
||||
def test_noNonce(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
|
||||
has no nonce.
|
||||
"""
|
||||
e = self.assertRaises(
|
||||
LoginFailed,
|
||||
self.credentialFactory.decode,
|
||||
self.formatResponse(opaque="abc123"),
|
||||
self.method, self.clientAddress.host)
|
||||
self.assertEqual(str(e), "Invalid response, no nonce given.")
|
||||
|
||||
|
||||
def test_noOpaque(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} if the response
|
||||
has no opaque.
|
||||
"""
|
||||
e = self.assertRaises(
|
||||
LoginFailed,
|
||||
self.credentialFactory.decode,
|
||||
self.formatResponse(),
|
||||
self.method, self.clientAddress.host)
|
||||
self.assertEqual(str(e), "Invalid response, no opaque given.")
|
||||
|
||||
|
||||
def test_checkHash(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} returns an L{IUsernameDigestHash}
|
||||
provider which can verify a hash of the form 'username:realm:password'.
|
||||
"""
|
||||
challenge = self.credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
nc = "00000001"
|
||||
clientResponse = self.formatResponse(
|
||||
nonce=challenge['nonce'],
|
||||
response=self.getDigestResponse(challenge, nc),
|
||||
nc=nc,
|
||||
opaque=challenge['opaque'])
|
||||
|
||||
creds = self.credentialFactory.decode(clientResponse, self.method,
|
||||
self.clientAddress.host)
|
||||
self.assertTrue(verifyObject(IUsernameDigestHash, creds))
|
||||
|
||||
cleartext = '%s:%s:%s' % (self.username, self.realm, self.password)
|
||||
hash = md5(cleartext)
|
||||
self.assertTrue(creds.checkHash(hash.hexdigest()))
|
||||
hash.update('wrong')
|
||||
self.assertFalse(creds.checkHash(hash.hexdigest()))
|
||||
|
||||
|
||||
def test_invalidOpaque(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the opaque
|
||||
value does not contain all the required parts.
|
||||
"""
|
||||
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
|
||||
self.realm)
|
||||
challenge = credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
'badOpaque',
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
|
||||
|
||||
badOpaque = 'foo-' + b64encode('nonce,clientip')
|
||||
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badOpaque,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
|
||||
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
'',
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(str(exc), 'Invalid response, invalid opaque value')
|
||||
|
||||
badOpaque = (
|
||||
'foo-' + b64encode('%s,%s,foobar' % (
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)))
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badOpaque,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(
|
||||
str(exc), 'Invalid response, invalid opaque/time values')
|
||||
|
||||
|
||||
def test_incompatibleNonce(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the given
|
||||
nonce from the response does not match the nonce encoded in the opaque.
|
||||
"""
|
||||
credentialFactory = FakeDigestCredentialFactory(self.algorithm, self.realm)
|
||||
challenge = credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
badNonceOpaque = credentialFactory._generateOpaque(
|
||||
'1234567890',
|
||||
self.clientAddress.host)
|
||||
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badNonceOpaque,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(
|
||||
str(exc),
|
||||
'Invalid response, incompatible opaque/nonce values')
|
||||
|
||||
exc = self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badNonceOpaque,
|
||||
'',
|
||||
self.clientAddress.host)
|
||||
self.assertEqual(
|
||||
str(exc),
|
||||
'Invalid response, incompatible opaque/nonce values')
|
||||
|
||||
|
||||
def test_incompatibleClientIP(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the
|
||||
request comes from a client IP other than what is encoded in the
|
||||
opaque.
|
||||
"""
|
||||
credentialFactory = FakeDigestCredentialFactory(self.algorithm, self.realm)
|
||||
challenge = credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
badAddress = '10.0.0.1'
|
||||
# Sanity check
|
||||
self.assertNotEqual(self.clientAddress.host, badAddress)
|
||||
|
||||
badNonceOpaque = credentialFactory._generateOpaque(
|
||||
challenge['nonce'], badAddress)
|
||||
|
||||
self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badNonceOpaque,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
|
||||
|
||||
def test_oldNonce(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the given
|
||||
opaque is older than C{DigestCredentialFactory.CHALLENGE_LIFETIME_SECS}
|
||||
"""
|
||||
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
|
||||
self.realm)
|
||||
challenge = credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
key = '%s,%s,%s' % (challenge['nonce'],
|
||||
self.clientAddress.host,
|
||||
'-137876876')
|
||||
digest = md5(key + credentialFactory.privateKey).hexdigest()
|
||||
ekey = b64encode(key)
|
||||
|
||||
oldNonceOpaque = '%s-%s' % (digest, ekey.strip('\n'))
|
||||
|
||||
self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
oldNonceOpaque,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
|
||||
|
||||
def test_mismatchedOpaqueChecksum(self):
|
||||
"""
|
||||
L{DigestCredentialFactory.decode} raises L{LoginFailed} when the opaque
|
||||
checksum fails verification.
|
||||
"""
|
||||
credentialFactory = FakeDigestCredentialFactory(self.algorithm,
|
||||
self.realm)
|
||||
challenge = credentialFactory.getChallenge(self.clientAddress.host)
|
||||
|
||||
key = '%s,%s,%s' % (challenge['nonce'],
|
||||
self.clientAddress.host,
|
||||
'0')
|
||||
|
||||
digest = md5(key + 'this is not the right pkey').hexdigest()
|
||||
badChecksum = '%s-%s' % (digest, b64encode(key))
|
||||
|
||||
self.assertRaises(
|
||||
LoginFailed,
|
||||
credentialFactory._verifyOpaque,
|
||||
badChecksum,
|
||||
challenge['nonce'],
|
||||
self.clientAddress.host)
|
||||
|
||||
|
||||
def test_incompatibleCalcHA1Options(self):
|
||||
"""
|
||||
L{calcHA1} raises L{TypeError} when any of the pszUsername, pszRealm,
|
||||
or pszPassword arguments are specified with the preHA1 keyword
|
||||
argument.
|
||||
"""
|
||||
arguments = (
|
||||
("user", "realm", "password", "preHA1"),
|
||||
(None, "realm", None, "preHA1"),
|
||||
(None, None, "password", "preHA1"),
|
||||
)
|
||||
|
||||
for pszUsername, pszRealm, pszPassword, preHA1 in arguments:
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
calcHA1,
|
||||
"md5",
|
||||
pszUsername,
|
||||
pszRealm,
|
||||
pszPassword,
|
||||
"nonce",
|
||||
"cnonce",
|
||||
preHA1=preHA1)
|
||||
|
||||
|
||||
def test_noNewlineOpaque(self):
|
||||
"""
|
||||
L{DigestCredentialFactory._generateOpaque} returns a value without
|
||||
newlines, regardless of the length of the nonce.
|
||||
"""
|
||||
opaque = self.credentialFactory._generateOpaque(
|
||||
"long nonce " * 10, None)
|
||||
self.assertNotIn('\n', opaque)
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for dirdbm module.
|
||||
"""
|
||||
|
||||
import os, shutil, glob
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.persisted import dirdbm
|
||||
|
||||
|
||||
|
||||
class DirDbmTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.path = self.mktemp()
|
||||
self.dbm = dirdbm.open(self.path)
|
||||
self.items = (('abc', 'foo'), ('/lalal', '\000\001'), ('\000\012', 'baz'))
|
||||
|
||||
|
||||
def testAll(self):
|
||||
k = "//==".decode("base64")
|
||||
self.dbm[k] = "a"
|
||||
self.dbm[k] = "a"
|
||||
self.assertEqual(self.dbm[k], "a")
|
||||
|
||||
|
||||
def testRebuildInteraction(self):
|
||||
from twisted.persisted import dirdbm
|
||||
from twisted.python import rebuild
|
||||
|
||||
s = dirdbm.Shelf('dirdbm.rebuild.test')
|
||||
s['key'] = 'value'
|
||||
rebuild.rebuild(dirdbm)
|
||||
# print s['key']
|
||||
|
||||
|
||||
def testDbm(self):
|
||||
d = self.dbm
|
||||
|
||||
# insert keys
|
||||
keys = []
|
||||
values = set()
|
||||
for k, v in self.items:
|
||||
d[k] = v
|
||||
keys.append(k)
|
||||
values.add(v)
|
||||
keys.sort()
|
||||
|
||||
# check they exist
|
||||
for k, v in self.items:
|
||||
assert d.has_key(k), "has_key() failed"
|
||||
assert d[k] == v, "database has wrong value"
|
||||
|
||||
# check non existent key
|
||||
try:
|
||||
d["XXX"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
assert 0, "didn't raise KeyError on non-existent key"
|
||||
|
||||
# check keys(), values() and items()
|
||||
dbkeys = list(d.keys())
|
||||
dbvalues = set(d.values())
|
||||
dbitems = set(d.items())
|
||||
dbkeys.sort()
|
||||
items = set(self.items)
|
||||
assert keys == dbkeys, ".keys() output didn't match: %s != %s" % (repr(keys), repr(dbkeys))
|
||||
assert values == dbvalues, ".values() output didn't match: %s != %s" % (repr(values), repr(dbvalues))
|
||||
assert items == dbitems, "items() didn't match: %s != %s" % (repr(items), repr(dbitems))
|
||||
|
||||
copyPath = self.mktemp()
|
||||
d2 = d.copyTo(copyPath)
|
||||
|
||||
copykeys = list(d.keys())
|
||||
copyvalues = set(d.values())
|
||||
copyitems = set(d.items())
|
||||
copykeys.sort()
|
||||
|
||||
assert dbkeys == copykeys, ".copyTo().keys() didn't match: %s != %s" % (repr(dbkeys), repr(copykeys))
|
||||
assert dbvalues == copyvalues, ".copyTo().values() didn't match: %s != %s" % (repr(dbvalues), repr(copyvalues))
|
||||
assert dbitems == copyitems, ".copyTo().items() didn't match: %s != %s" % (repr(dbkeys), repr(copyitems))
|
||||
|
||||
d2.clear()
|
||||
assert len(d2.keys()) == len(d2.values()) == len(d2.items()) == 0, ".clear() failed"
|
||||
shutil.rmtree(copyPath)
|
||||
|
||||
# delete items
|
||||
for k, v in self.items:
|
||||
del d[k]
|
||||
assert not d.has_key(k), "has_key() even though we deleted it"
|
||||
assert len(d.keys()) == 0, "database has keys"
|
||||
assert len(d.values()) == 0, "database has values"
|
||||
assert len(d.items()) == 0, "database has items"
|
||||
|
||||
|
||||
def testModificationTime(self):
|
||||
import time
|
||||
# the mtime value for files comes from a different place than the
|
||||
# gettimeofday() system call. On linux, gettimeofday() can be
|
||||
# slightly ahead (due to clock drift which gettimeofday() takes into
|
||||
# account but which open()/write()/close() do not), and if we are
|
||||
# close to the edge of the next second, time.time() can give a value
|
||||
# which is larger than the mtime which results from a subsequent
|
||||
# write(). I consider this a kernel bug, but it is beyond the scope
|
||||
# of this test. Thus we keep the range of acceptability to 3 seconds time.
|
||||
# -warner
|
||||
self.dbm["k"] = "v"
|
||||
self.assert_(abs(time.time() - self.dbm.getModificationTime("k")) <= 3)
|
||||
|
||||
|
||||
def testRecovery(self):
|
||||
"""DirDBM: test recovery from directory after a faked crash"""
|
||||
k = self.dbm._encode("key1")
|
||||
f = open(os.path.join(self.path, k + ".rpl"), "wb")
|
||||
f.write("value")
|
||||
f.close()
|
||||
|
||||
k2 = self.dbm._encode("key2")
|
||||
f = open(os.path.join(self.path, k2), "wb")
|
||||
f.write("correct")
|
||||
f.close()
|
||||
f = open(os.path.join(self.path, k2 + ".rpl"), "wb")
|
||||
f.write("wrong")
|
||||
f.close()
|
||||
|
||||
f = open(os.path.join(self.path, "aa.new"), "wb")
|
||||
f.write("deleted")
|
||||
f.close()
|
||||
|
||||
dbm = dirdbm.DirDBM(self.path)
|
||||
assert dbm["key1"] == "value"
|
||||
assert dbm["key2"] == "correct"
|
||||
assert not glob.glob(os.path.join(self.path, "*.new"))
|
||||
assert not glob.glob(os.path.join(self.path, "*.rpl"))
|
||||
|
||||
|
||||
def test_nonStringKeys(self):
|
||||
"""
|
||||
L{dirdbm.DirDBM} operations only support string keys: other types
|
||||
should raise a C{AssertionError}. This really ought to be a
|
||||
C{TypeError}, but it'll stay like this for backward compatibility.
|
||||
"""
|
||||
self.assertRaises(AssertionError, self.dbm.__setitem__, 2, "3")
|
||||
try:
|
||||
self.assertRaises(AssertionError, self.dbm.__setitem__, "2", 3)
|
||||
except unittest.FailTest:
|
||||
# dirdbm.Shelf.__setitem__ supports non-string values
|
||||
self.assertIsInstance(self.dbm, dirdbm.Shelf)
|
||||
self.assertRaises(AssertionError, self.dbm.__getitem__, 2)
|
||||
self.assertRaises(AssertionError, self.dbm.__delitem__, 2)
|
||||
self.assertRaises(AssertionError, self.dbm.has_key, 2)
|
||||
self.assertRaises(AssertionError, self.dbm.__contains__, 2)
|
||||
self.assertRaises(AssertionError, self.dbm.getModificationTime, 2)
|
||||
|
||||
|
||||
|
||||
class ShelfTestCase(DirDbmTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.path = self.mktemp()
|
||||
self.dbm = dirdbm.Shelf(self.path)
|
||||
self.items = (('abc', 'foo'), ('/lalal', '\000\001'), ('\000\012', 'baz'),
|
||||
('int', 12), ('float', 12.0), ('tuple', (None, 12)))
|
||||
|
||||
|
||||
testCases = [DirDbmTestCase, ShelfTestCase]
|
||||
104
Linux_i686/lib/python2.7/site-packages/twisted/test/test_doc.py
Normal file
104
Linux_i686/lib/python2.7/site-packages/twisted/test/test_doc.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import inspect, glob
|
||||
from os import path
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import reflect
|
||||
from twisted.python.modules import getModule
|
||||
|
||||
|
||||
|
||||
def errorInFile(f, line=17, name=''):
|
||||
"""
|
||||
Return a filename formatted so emacs will recognize it as an error point
|
||||
|
||||
@param line: Line number in file. Defaults to 17 because that's about how
|
||||
long the copyright headers are.
|
||||
"""
|
||||
return '%s:%d:%s' % (f, line, name)
|
||||
# return 'File "%s", line %d, in %s' % (f, line, name)
|
||||
|
||||
|
||||
class DocCoverage(unittest.TestCase):
|
||||
"""
|
||||
Looking for docstrings in all modules and packages.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.packageNames = []
|
||||
for mod in getModule('twisted').walkModules():
|
||||
if mod.isPackage():
|
||||
self.packageNames.append(mod.name)
|
||||
|
||||
|
||||
def testModules(self):
|
||||
"""
|
||||
Looking for docstrings in all modules.
|
||||
"""
|
||||
docless = []
|
||||
for packageName in self.packageNames:
|
||||
if packageName in ('twisted.test',):
|
||||
# because some stuff in here behaves oddly when imported
|
||||
continue
|
||||
try:
|
||||
package = reflect.namedModule(packageName)
|
||||
except ImportError, e:
|
||||
# This is testing doc coverage, not importability.
|
||||
# (Really, I don't want to deal with the fact that I don't
|
||||
# have pyserial installed.)
|
||||
# print e
|
||||
pass
|
||||
else:
|
||||
docless.extend(self.modulesInPackage(packageName, package))
|
||||
self.failIf(docless, "No docstrings in module files:\n"
|
||||
"%s" % ('\n'.join(map(errorInFile, docless)),))
|
||||
|
||||
|
||||
def modulesInPackage(self, packageName, package):
|
||||
docless = []
|
||||
directory = path.dirname(package.__file__)
|
||||
for modfile in glob.glob(path.join(directory, '*.py')):
|
||||
moduleName = inspect.getmodulename(modfile)
|
||||
if moduleName == '__init__':
|
||||
# These are tested by test_packages.
|
||||
continue
|
||||
elif moduleName in ('spelunk_gnome','gtkmanhole'):
|
||||
# argh special case pygtk evil argh. How does epydoc deal
|
||||
# with this?
|
||||
continue
|
||||
try:
|
||||
module = reflect.namedModule('.'.join([packageName,
|
||||
moduleName]))
|
||||
except Exception, e:
|
||||
# print moduleName, "misbehaved:", e
|
||||
pass
|
||||
else:
|
||||
if not inspect.getdoc(module):
|
||||
docless.append(modfile)
|
||||
return docless
|
||||
|
||||
|
||||
def testPackages(self):
|
||||
"""
|
||||
Looking for docstrings in all packages.
|
||||
"""
|
||||
docless = []
|
||||
for packageName in self.packageNames:
|
||||
try:
|
||||
package = reflect.namedModule(packageName)
|
||||
except Exception, e:
|
||||
# This is testing doc coverage, not importability.
|
||||
# (Really, I don't want to deal with the fact that I don't
|
||||
# have pyserial installed.)
|
||||
# print e
|
||||
pass
|
||||
else:
|
||||
if not inspect.getdoc(package):
|
||||
docless.append(package.__file__.replace('.pyc','.py'))
|
||||
self.failIf(docless, "No docstrings for package files\n"
|
||||
"%s" % ('\n'.join(map(errorInFile, docless),)))
|
||||
|
||||
|
||||
# This test takes a while and doesn't come close to passing. :(
|
||||
testModules.skip = "Activate me when you feel like writing docstrings, and fixing GTK crashing bugs."
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import socket, errno
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import error
|
||||
from twisted.python.runtime import platformType
|
||||
|
||||
|
||||
class TestStringification(unittest.SynchronousTestCase):
|
||||
"""Test that the exceptions have useful stringifications.
|
||||
"""
|
||||
|
||||
listOfTests = [
|
||||
#(output, exception[, args[, kwargs]]),
|
||||
|
||||
("An error occurred binding to an interface.",
|
||||
error.BindError),
|
||||
|
||||
("An error occurred binding to an interface: foo.",
|
||||
error.BindError, ['foo']),
|
||||
|
||||
("An error occurred binding to an interface: foo bar.",
|
||||
error.BindError, ['foo', 'bar']),
|
||||
|
||||
("Couldn't listen on eth0:4242: Foo.",
|
||||
error.CannotListenError,
|
||||
('eth0', 4242, socket.error('Foo'))),
|
||||
|
||||
("Message is too long to send.",
|
||||
error.MessageLengthError),
|
||||
|
||||
("Message is too long to send: foo bar.",
|
||||
error.MessageLengthError, ['foo', 'bar']),
|
||||
|
||||
("DNS lookup failed.",
|
||||
error.DNSLookupError),
|
||||
|
||||
("DNS lookup failed: foo bar.",
|
||||
error.DNSLookupError, ['foo', 'bar']),
|
||||
|
||||
("An error occurred while connecting.",
|
||||
error.ConnectError),
|
||||
|
||||
("An error occurred while connecting: someOsError.",
|
||||
error.ConnectError, ['someOsError']),
|
||||
|
||||
("An error occurred while connecting: foo.",
|
||||
error.ConnectError, [], {'string': 'foo'}),
|
||||
|
||||
("An error occurred while connecting: someOsError: foo.",
|
||||
error.ConnectError, ['someOsError', 'foo']),
|
||||
|
||||
("Couldn't bind.",
|
||||
error.ConnectBindError),
|
||||
|
||||
("Couldn't bind: someOsError.",
|
||||
error.ConnectBindError, ['someOsError']),
|
||||
|
||||
("Couldn't bind: someOsError: foo.",
|
||||
error.ConnectBindError, ['someOsError', 'foo']),
|
||||
|
||||
("Hostname couldn't be looked up.",
|
||||
error.UnknownHostError),
|
||||
|
||||
("No route to host.",
|
||||
error.NoRouteError),
|
||||
|
||||
("Connection was refused by other side.",
|
||||
error.ConnectionRefusedError),
|
||||
|
||||
("TCP connection timed out.",
|
||||
error.TCPTimedOutError),
|
||||
|
||||
("File used for UNIX socket is no good.",
|
||||
error.BadFileError),
|
||||
|
||||
("Service name given as port is unknown.",
|
||||
error.ServiceNameUnknownError),
|
||||
|
||||
("User aborted connection.",
|
||||
error.UserError),
|
||||
|
||||
("User timeout caused connection failure.",
|
||||
error.TimeoutError),
|
||||
|
||||
("An SSL error occurred.",
|
||||
error.SSLError),
|
||||
|
||||
("Connection to the other side was lost in a non-clean fashion.",
|
||||
error.ConnectionLost),
|
||||
|
||||
("Connection to the other side was lost in a non-clean fashion: foo bar.",
|
||||
error.ConnectionLost, ['foo', 'bar']),
|
||||
|
||||
("Connection was closed cleanly.",
|
||||
error.ConnectionDone),
|
||||
|
||||
("Connection was closed cleanly: foo bar.",
|
||||
error.ConnectionDone, ['foo', 'bar']),
|
||||
|
||||
("Uh.", #TODO nice docstring, you've got there.
|
||||
error.ConnectionFdescWentAway),
|
||||
|
||||
("Tried to cancel an already-called event.",
|
||||
error.AlreadyCalled),
|
||||
|
||||
("Tried to cancel an already-called event: foo bar.",
|
||||
error.AlreadyCalled, ['foo', 'bar']),
|
||||
|
||||
("Tried to cancel an already-cancelled event.",
|
||||
error.AlreadyCancelled),
|
||||
|
||||
("Tried to cancel an already-cancelled event: x 2.",
|
||||
error.AlreadyCancelled, ["x", "2"]),
|
||||
|
||||
("A process has ended without apparent errors: process finished with exit code 0.",
|
||||
error.ProcessDone,
|
||||
[None]),
|
||||
|
||||
("A process has ended with a probable error condition: process ended.",
|
||||
error.ProcessTerminated),
|
||||
|
||||
("A process has ended with a probable error condition: process ended with exit code 42.",
|
||||
error.ProcessTerminated,
|
||||
[],
|
||||
{'exitCode': 42}),
|
||||
|
||||
("A process has ended with a probable error condition: process ended by signal SIGBUS.",
|
||||
error.ProcessTerminated,
|
||||
[],
|
||||
{'signal': 'SIGBUS'}),
|
||||
|
||||
("The Connector was not connecting when it was asked to stop connecting.",
|
||||
error.NotConnectingError),
|
||||
|
||||
("The Connector was not connecting when it was asked to stop connecting: x 13.",
|
||||
error.NotConnectingError, ["x", "13"]),
|
||||
|
||||
("The Port was not listening when it was asked to stop listening.",
|
||||
error.NotListeningError),
|
||||
|
||||
("The Port was not listening when it was asked to stop listening: a 12.",
|
||||
error.NotListeningError, ["a", "12"]),
|
||||
]
|
||||
|
||||
def testThemAll(self):
|
||||
for entry in self.listOfTests:
|
||||
output = entry[0]
|
||||
exception = entry[1]
|
||||
try:
|
||||
args = entry[2]
|
||||
except IndexError:
|
||||
args = ()
|
||||
try:
|
||||
kwargs = entry[3]
|
||||
except IndexError:
|
||||
kwargs = {}
|
||||
|
||||
self.assertEqual(
|
||||
str(exception(*args, **kwargs)),
|
||||
output)
|
||||
|
||||
|
||||
def test_connectingCancelledError(self):
|
||||
"""
|
||||
L{error.ConnectingCancelledError} has an C{address} attribute.
|
||||
"""
|
||||
address = object()
|
||||
e = error.ConnectingCancelledError(address)
|
||||
self.assertIdentical(e.address, address)
|
||||
|
||||
|
||||
|
||||
class SubclassingTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Some exceptions are subclasses of other exceptions.
|
||||
"""
|
||||
|
||||
def test_connectionLostSubclassOfConnectionClosed(self):
|
||||
"""
|
||||
L{error.ConnectionClosed} is a superclass of L{error.ConnectionLost}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.ConnectionLost,
|
||||
error.ConnectionClosed))
|
||||
|
||||
|
||||
def test_connectionDoneSubclassOfConnectionClosed(self):
|
||||
"""
|
||||
L{error.ConnectionClosed} is a superclass of L{error.ConnectionDone}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.ConnectionDone,
|
||||
error.ConnectionClosed))
|
||||
|
||||
|
||||
def test_invalidAddressErrorSubclassOfValueError(self):
|
||||
"""
|
||||
L{ValueError} is a superclass of L{error.InvalidAddressError}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.InvalidAddressError,
|
||||
ValueError))
|
||||
|
||||
|
||||
|
||||
|
||||
class GetConnectErrorTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Given an exception instance thrown by C{socket.connect},
|
||||
L{error.getConnectError} returns the appropriate high-level Twisted
|
||||
exception instance.
|
||||
"""
|
||||
|
||||
def assertErrnoException(self, errno, expectedClass):
|
||||
"""
|
||||
When called with a tuple with the given errno,
|
||||
L{error.getConnectError} returns an exception which is an instance of
|
||||
the expected class.
|
||||
"""
|
||||
e = (errno, "lalala")
|
||||
result = error.getConnectError(e)
|
||||
self.assertCorrectException(errno, "lalala", result, expectedClass)
|
||||
|
||||
|
||||
def assertCorrectException(self, errno, message, result, expectedClass):
|
||||
"""
|
||||
The given result of L{error.getConnectError} has the given attributes
|
||||
(C{osError} and C{args}), and is an instance of the given class.
|
||||
"""
|
||||
|
||||
# Want exact class match, not inherited classes, so no isinstance():
|
||||
self.assertEqual(result.__class__, expectedClass)
|
||||
self.assertEqual(result.osError, errno)
|
||||
self.assertEqual(result.args, (message,))
|
||||
|
||||
|
||||
def test_errno(self):
|
||||
"""
|
||||
L{error.getConnectError} converts based on errno for C{socket.error}.
|
||||
"""
|
||||
self.assertErrnoException(errno.ENETUNREACH, error.NoRouteError)
|
||||
self.assertErrnoException(errno.ECONNREFUSED, error.ConnectionRefusedError)
|
||||
self.assertErrnoException(errno.ETIMEDOUT, error.TCPTimedOutError)
|
||||
if platformType == "win32":
|
||||
self.assertErrnoException(errno.WSAECONNREFUSED, error.ConnectionRefusedError)
|
||||
self.assertErrnoException(errno.WSAENETUNREACH, error.NoRouteError)
|
||||
|
||||
|
||||
def test_gaierror(self):
|
||||
"""
|
||||
L{error.getConnectError} converts to a L{error.UnknownHostError} given
|
||||
a C{socket.gaierror} instance.
|
||||
"""
|
||||
result = error.getConnectError(socket.gaierror(12, "hello"))
|
||||
self.assertCorrectException(12, "hello", result, error.UnknownHostError)
|
||||
|
||||
|
||||
def test_nonTuple(self):
|
||||
"""
|
||||
L{error.getConnectError} converts to a L{error.ConnectError} given
|
||||
an argument that cannot be unpacked.
|
||||
"""
|
||||
e = Exception()
|
||||
result = error.getConnectError(e)
|
||||
self.assertCorrectException(None, e, result, error.ConnectError)
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for explorer
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.manhole import explorer
|
||||
|
||||
import types
|
||||
|
||||
"""
|
||||
# Tests:
|
||||
|
||||
Get an ObjectLink. Browse ObjectLink.identifier. Is it the same?
|
||||
|
||||
Watch Object. Make sure an ObjectLink is received when:
|
||||
Call a method.
|
||||
Set an attribute.
|
||||
|
||||
Have an Object with a setattr class. Watch it.
|
||||
Do both the navite setattr and the watcher get called?
|
||||
|
||||
Sequences with circular references. Does it blow up?
|
||||
"""
|
||||
|
||||
class SomeDohickey:
|
||||
def __init__(self, *a):
|
||||
self.__dict__['args'] = a
|
||||
|
||||
def bip(self):
|
||||
return self.args
|
||||
|
||||
|
||||
class TestBrowser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.pool = explorer.explorerPool
|
||||
self.pool.clear()
|
||||
self.testThing = ["How many stairs must a man climb down?",
|
||||
SomeDohickey(42)]
|
||||
|
||||
def test_chain(self):
|
||||
"Following a chain of Explorers."
|
||||
xplorer = self.pool.getExplorer(self.testThing, 'testThing')
|
||||
self.assertEqual(xplorer.id, id(self.testThing))
|
||||
self.assertEqual(xplorer.identifier, 'testThing')
|
||||
|
||||
dxplorer = xplorer.get_elements()[1]
|
||||
self.assertEqual(dxplorer.id, id(self.testThing[1]))
|
||||
|
||||
class Watcher:
|
||||
zero = 0
|
||||
def __init__(self):
|
||||
self.links = []
|
||||
|
||||
def receiveBrowserObject(self, olink):
|
||||
self.links.append(olink)
|
||||
|
||||
def setZero(self):
|
||||
self.zero = len(self.links)
|
||||
|
||||
def len(self):
|
||||
return len(self.links) - self.zero
|
||||
|
||||
|
||||
class SetattrDohickey:
|
||||
def __setattr__(self, k, v):
|
||||
v = list(str(v))
|
||||
v.reverse()
|
||||
self.__dict__[k] = ''.join(v)
|
||||
|
||||
class MiddleMan(SomeDohickey, SetattrDohickey):
|
||||
pass
|
||||
|
||||
# class TestWatch(unittest.TestCase):
|
||||
class FIXME_Watch:
|
||||
def setUp(self):
|
||||
self.globalNS = globals().copy()
|
||||
self.localNS = {}
|
||||
self.browser = explorer.ObjectBrowser(self.globalNS, self.localNS)
|
||||
self.watcher = Watcher()
|
||||
|
||||
def test_setAttrPlain(self):
|
||||
"Triggering a watcher response by setting an attribute."
|
||||
|
||||
testThing = SomeDohickey('pencil')
|
||||
self.browser.watchObject(testThing, 'testThing',
|
||||
self.watcher.receiveBrowserObject)
|
||||
self.watcher.setZero()
|
||||
|
||||
testThing.someAttr = 'someValue'
|
||||
|
||||
self.assertEqual(testThing.someAttr, 'someValue')
|
||||
self.failUnless(self.watcher.len())
|
||||
olink = self.watcher.links[-1]
|
||||
self.assertEqual(olink.id, id(testThing))
|
||||
|
||||
def test_setAttrChain(self):
|
||||
"Setting an attribute on a watched object that has __setattr__"
|
||||
testThing = MiddleMan('pencil')
|
||||
|
||||
self.browser.watchObject(testThing, 'testThing',
|
||||
self.watcher.receiveBrowserObject)
|
||||
self.watcher.setZero()
|
||||
|
||||
testThing.someAttr = 'ZORT'
|
||||
|
||||
self.assertEqual(testThing.someAttr, 'TROZ')
|
||||
self.failUnless(self.watcher.len())
|
||||
olink = self.watcher.links[-1]
|
||||
self.assertEqual(olink.id, id(testThing))
|
||||
|
||||
|
||||
def test_method(self):
|
||||
"Triggering a watcher response by invoking a method."
|
||||
|
||||
for testThing in (SomeDohickey('pencil'), MiddleMan('pencil')):
|
||||
self.browser.watchObject(testThing, 'testThing',
|
||||
self.watcher.receiveBrowserObject)
|
||||
self.watcher.setZero()
|
||||
|
||||
rval = testThing.bip()
|
||||
self.assertEqual(rval, ('pencil',))
|
||||
|
||||
self.failUnless(self.watcher.len())
|
||||
olink = self.watcher.links[-1]
|
||||
self.assertEqual(olink.id, id(testThing))
|
||||
|
||||
|
||||
def function_noArgs():
|
||||
"A function which accepts no arguments at all."
|
||||
return
|
||||
|
||||
def function_simple(a, b, c):
|
||||
"A function which accepts several arguments."
|
||||
return a, b, c
|
||||
|
||||
def function_variable(*a, **kw):
|
||||
"A function which accepts a variable number of args and keywords."
|
||||
return a, kw
|
||||
|
||||
def function_crazy((alpha, beta), c, d=range(4), **kw):
|
||||
"A function with a mad crazy signature."
|
||||
return alpha, beta, c, d, kw
|
||||
|
||||
class TestBrowseFunction(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.pool = explorer.explorerPool
|
||||
self.pool.clear()
|
||||
|
||||
def test_sanity(self):
|
||||
"""Basic checks for browse_function.
|
||||
|
||||
Was the proper type returned? Does it have the right name and ID?
|
||||
"""
|
||||
for f_name in ('function_noArgs', 'function_simple',
|
||||
'function_variable', 'function_crazy'):
|
||||
f = eval(f_name)
|
||||
|
||||
xplorer = self.pool.getExplorer(f, f_name)
|
||||
|
||||
self.assertEqual(xplorer.id, id(f))
|
||||
|
||||
self.failUnless(isinstance(xplorer, explorer.ExplorerFunction))
|
||||
|
||||
self.assertEqual(xplorer.name, f_name)
|
||||
|
||||
def test_signature_noArgs(self):
|
||||
"""Testing zero-argument function signature.
|
||||
"""
|
||||
|
||||
xplorer = self.pool.getExplorer(function_noArgs, 'function_noArgs')
|
||||
|
||||
self.assertEqual(len(xplorer.signature), 0)
|
||||
|
||||
def test_signature_simple(self):
|
||||
"""Testing simple function signature.
|
||||
"""
|
||||
|
||||
xplorer = self.pool.getExplorer(function_simple, 'function_simple')
|
||||
|
||||
expected_signature = ('a','b','c')
|
||||
|
||||
self.assertEqual(xplorer.signature.name, expected_signature)
|
||||
|
||||
def test_signature_variable(self):
|
||||
"""Testing variable-argument function signature.
|
||||
"""
|
||||
|
||||
xplorer = self.pool.getExplorer(function_variable,
|
||||
'function_variable')
|
||||
|
||||
expected_names = ('a','kw')
|
||||
signature = xplorer.signature
|
||||
|
||||
self.assertEqual(signature.name, expected_names)
|
||||
self.failUnless(signature.is_varlist(0))
|
||||
self.failUnless(signature.is_keyword(1))
|
||||
|
||||
def test_signature_crazy(self):
|
||||
"""Testing function with crazy signature.
|
||||
"""
|
||||
xplorer = self.pool.getExplorer(function_crazy, 'function_crazy')
|
||||
|
||||
signature = xplorer.signature
|
||||
|
||||
expected_signature = [{'name': 'c'},
|
||||
{'name': 'd',
|
||||
'default': range(4)},
|
||||
{'name': 'kw',
|
||||
'keywords': 1}]
|
||||
|
||||
# The name of the first argument seems to be indecipherable,
|
||||
# but make sure it has one (and no default).
|
||||
self.failUnless(signature.get_name(0))
|
||||
self.failUnless(not signature.get_default(0)[0])
|
||||
|
||||
self.assertEqual(signature.get_name(1), 'c')
|
||||
|
||||
# Get a list of values from a list of ExplorerImmutables.
|
||||
arg_2_default = map(lambda l: l.value,
|
||||
signature.get_default(2)[1].get_elements())
|
||||
|
||||
self.assertEqual(signature.get_name(2), 'd')
|
||||
self.assertEqual(arg_2_default, range(4))
|
||||
|
||||
self.assertEqual(signature.get_name(3), 'kw')
|
||||
self.failUnless(signature.is_keyword(3))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test code for basic Factory classes.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import pickle
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.protocol import ReconnectingClientFactory, Protocol
|
||||
|
||||
|
||||
class FakeConnector(object):
|
||||
"""
|
||||
A fake connector class, to be used to mock connections failed or lost.
|
||||
"""
|
||||
|
||||
def stopConnecting(self):
|
||||
pass
|
||||
|
||||
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class ReconnectingFactoryTestCase(TestCase):
|
||||
"""
|
||||
Tests for L{ReconnectingClientFactory}.
|
||||
"""
|
||||
|
||||
def test_stopTryingWhenConnected(self):
|
||||
"""
|
||||
If a L{ReconnectingClientFactory} has C{stopTrying} called while it is
|
||||
connected, it does not subsequently attempt to reconnect if the
|
||||
connection is later lost.
|
||||
"""
|
||||
class NoConnectConnector(object):
|
||||
def stopConnecting(self):
|
||||
raise RuntimeError("Shouldn't be called, we're connected.")
|
||||
def connect(self):
|
||||
raise RuntimeError("Shouldn't be reconnecting.")
|
||||
|
||||
c = ReconnectingClientFactory()
|
||||
c.protocol = Protocol
|
||||
# Let's pretend we've connected:
|
||||
c.buildProtocol(None)
|
||||
# Now we stop trying, then disconnect:
|
||||
c.stopTrying()
|
||||
c.clientConnectionLost(NoConnectConnector(), None)
|
||||
self.assertFalse(c.continueTrying)
|
||||
|
||||
|
||||
def test_stopTryingDoesNotReconnect(self):
|
||||
"""
|
||||
Calling stopTrying on a L{ReconnectingClientFactory} doesn't attempt a
|
||||
retry on any active connector.
|
||||
"""
|
||||
class FactoryAwareFakeConnector(FakeConnector):
|
||||
attemptedRetry = False
|
||||
|
||||
def stopConnecting(self):
|
||||
"""
|
||||
Behave as though an ongoing connection attempt has now
|
||||
failed, and notify the factory of this.
|
||||
"""
|
||||
f.clientConnectionFailed(self, None)
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Record an attempt to reconnect, since this is what we
|
||||
are trying to avoid.
|
||||
"""
|
||||
self.attemptedRetry = True
|
||||
|
||||
f = ReconnectingClientFactory()
|
||||
f.clock = Clock()
|
||||
|
||||
# simulate an active connection - stopConnecting on this connector should
|
||||
# be triggered when we call stopTrying
|
||||
f.connector = FactoryAwareFakeConnector()
|
||||
f.stopTrying()
|
||||
|
||||
# make sure we never attempted to retry
|
||||
self.assertFalse(f.connector.attemptedRetry)
|
||||
self.assertFalse(f.clock.getDelayedCalls())
|
||||
|
||||
|
||||
def test_serializeUnused(self):
|
||||
"""
|
||||
A L{ReconnectingClientFactory} which hasn't been used for anything
|
||||
can be pickled and unpickled and end up with the same state.
|
||||
"""
|
||||
original = ReconnectingClientFactory()
|
||||
reconstituted = pickle.loads(pickle.dumps(original))
|
||||
self.assertEqual(original.__dict__, reconstituted.__dict__)
|
||||
|
||||
|
||||
def test_serializeWithClock(self):
|
||||
"""
|
||||
The clock attribute of L{ReconnectingClientFactory} is not serialized,
|
||||
and the restored value sets it to the default value, the reactor.
|
||||
"""
|
||||
clock = Clock()
|
||||
original = ReconnectingClientFactory()
|
||||
original.clock = clock
|
||||
reconstituted = pickle.loads(pickle.dumps(original))
|
||||
self.assertIdentical(reconstituted.clock, None)
|
||||
|
||||
|
||||
def test_deserializationResetsParameters(self):
|
||||
"""
|
||||
A L{ReconnectingClientFactory} which is unpickled does not have an
|
||||
L{IConnector} and has its reconnecting timing parameters reset to their
|
||||
initial values.
|
||||
"""
|
||||
factory = ReconnectingClientFactory()
|
||||
factory.clientConnectionFailed(FakeConnector(), None)
|
||||
self.addCleanup(factory.stopTrying)
|
||||
|
||||
serialized = pickle.dumps(factory)
|
||||
unserialized = pickle.loads(serialized)
|
||||
self.assertEqual(unserialized.connector, None)
|
||||
self.assertEqual(unserialized._callID, None)
|
||||
self.assertEqual(unserialized.retries, 0)
|
||||
self.assertEqual(unserialized.delay, factory.initialDelay)
|
||||
self.assertEqual(unserialized.continueTrying, True)
|
||||
|
||||
|
||||
def test_parametrizedClock(self):
|
||||
"""
|
||||
The clock used by L{ReconnectingClientFactory} can be parametrized, so
|
||||
that one can cleanly test reconnections.
|
||||
"""
|
||||
clock = Clock()
|
||||
factory = ReconnectingClientFactory()
|
||||
factory.clock = clock
|
||||
|
||||
factory.clientConnectionLost(FakeConnector(), None)
|
||||
self.assertEqual(len(clock.calls), 1)
|
||||
|
|
@ -0,0 +1,993 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for the L{twisted.python.failure} module.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import pdb
|
||||
import linecache
|
||||
|
||||
from twisted.python.compat import NativeStringIO, _PY3
|
||||
from twisted.python import reflect
|
||||
from twisted.python import failure
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
|
||||
|
||||
try:
|
||||
from twisted.test import raiser
|
||||
except ImportError:
|
||||
raiser = None
|
||||
|
||||
|
||||
|
||||
def getDivisionFailure(*args, **kwargs):
|
||||
"""
|
||||
Make a C{Failure} of a divide-by-zero error.
|
||||
|
||||
@param args: Any C{*args} are passed to Failure's constructor.
|
||||
@param kwargs: Any C{**kwargs} are passed to Failure's constructor.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
f = failure.Failure(*args, **kwargs)
|
||||
return f
|
||||
|
||||
|
||||
class FailureTestCase(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{failure.Failure}.
|
||||
"""
|
||||
|
||||
def test_failAndTrap(self):
|
||||
"""
|
||||
Trapping a L{Failure}.
|
||||
"""
|
||||
try:
|
||||
raise NotImplementedError('test')
|
||||
except:
|
||||
f = failure.Failure()
|
||||
error = f.trap(SystemExit, RuntimeError)
|
||||
self.assertEqual(error, RuntimeError)
|
||||
self.assertEqual(f.type, NotImplementedError)
|
||||
|
||||
|
||||
def test_trapRaisesCurrentFailure(self):
|
||||
"""
|
||||
If the wrapped C{Exception} is not a subclass of one of the
|
||||
expected types, L{failure.Failure.trap} raises the current
|
||||
L{failure.Failure} ie C{self}.
|
||||
"""
|
||||
exception = ValueError()
|
||||
try:
|
||||
raise exception
|
||||
except:
|
||||
f = failure.Failure()
|
||||
untrapped = self.assertRaises(failure.Failure, f.trap, OverflowError)
|
||||
self.assertIdentical(f, untrapped)
|
||||
|
||||
|
||||
if _PY3:
|
||||
test_trapRaisesCurrentFailure.skip = (
|
||||
"In Python3, Failure.trap raises the wrapped Exception "
|
||||
"instead of the original Failure instance.")
|
||||
|
||||
|
||||
def test_trapRaisesWrappedException(self):
|
||||
"""
|
||||
If the wrapped C{Exception} is not a subclass of one of the
|
||||
expected types, L{failure.Failure.trap} raises the wrapped
|
||||
C{Exception}.
|
||||
"""
|
||||
exception = ValueError()
|
||||
try:
|
||||
raise exception
|
||||
except:
|
||||
f = failure.Failure()
|
||||
|
||||
untrapped = self.assertRaises(ValueError, f.trap, OverflowError)
|
||||
self.assertIdentical(exception, untrapped)
|
||||
|
||||
|
||||
if not _PY3:
|
||||
test_trapRaisesWrappedException.skip = (
|
||||
"In Python2, Failure.trap raises the current Failure instance "
|
||||
"instead of the wrapped Exception.")
|
||||
|
||||
|
||||
def test_failureValueFromFailure(self):
|
||||
"""
|
||||
A L{failure.Failure} constructed from another
|
||||
L{failure.Failure} instance, has its C{value} property set to
|
||||
the value of that L{failure.Failure} instance.
|
||||
"""
|
||||
exception = ValueError()
|
||||
f1 = failure.Failure(exception)
|
||||
f2 = failure.Failure(f1)
|
||||
self.assertIdentical(f2.value, exception)
|
||||
|
||||
|
||||
def test_failureValueFromFoundFailure(self):
|
||||
"""
|
||||
A L{failure.Failure} constructed without a C{exc_value}
|
||||
argument, will search for an "original" C{Failure}, and if
|
||||
found, its value will be used as the value for the new
|
||||
C{Failure}.
|
||||
"""
|
||||
exception = ValueError()
|
||||
f1 = failure.Failure(exception)
|
||||
try:
|
||||
f1.trap(OverflowError)
|
||||
except:
|
||||
f2 = failure.Failure()
|
||||
|
||||
self.assertIdentical(f2.value, exception)
|
||||
|
||||
|
||||
def assertStartsWith(self, s, prefix):
|
||||
"""
|
||||
Assert that C{s} starts with a particular C{prefix}.
|
||||
|
||||
@param s: The input string.
|
||||
@type s: C{str}
|
||||
@param prefix: The string that C{s} should start with.
|
||||
@type prefix: C{str}
|
||||
"""
|
||||
self.assertTrue(s.startswith(prefix),
|
||||
'%r is not the start of %r' % (prefix, s))
|
||||
|
||||
|
||||
def assertEndsWith(self, s, suffix):
|
||||
"""
|
||||
Assert that C{s} end with a particular C{suffix}.
|
||||
|
||||
@param s: The input string.
|
||||
@type s: C{str}
|
||||
@param suffix: The string that C{s} should end with.
|
||||
@type suffix: C{str}
|
||||
"""
|
||||
self.assertTrue(s.endswith(suffix),
|
||||
'%r is not the end of %r' % (suffix, s))
|
||||
|
||||
|
||||
def assertTracebackFormat(self, tb, prefix, suffix):
|
||||
"""
|
||||
Assert that the C{tb} traceback contains a particular C{prefix} and
|
||||
C{suffix}.
|
||||
|
||||
@param tb: The traceback string.
|
||||
@type tb: C{str}
|
||||
@param prefix: The string that C{tb} should start with.
|
||||
@type prefix: C{str}
|
||||
@param suffix: The string that C{tb} should end with.
|
||||
@type suffix: C{str}
|
||||
"""
|
||||
self.assertStartsWith(tb, prefix)
|
||||
self.assertEndsWith(tb, suffix)
|
||||
|
||||
|
||||
def assertDetailedTraceback(self, captureVars=False, cleanFailure=False):
|
||||
"""
|
||||
Assert that L{printDetailedTraceback} produces and prints a detailed
|
||||
traceback.
|
||||
|
||||
The detailed traceback consists of a header::
|
||||
|
||||
*--- Failure #20 ---
|
||||
|
||||
The body contains the stacktrace::
|
||||
|
||||
/twisted/trial/_synctest.py:1180: _run(...)
|
||||
/twisted/python/util.py:1076: runWithWarningsSuppressed(...)
|
||||
--- <exception caught here> ---
|
||||
/twisted/test/test_failure.py:39: getDivisionFailure(...)
|
||||
|
||||
If C{captureVars} is enabled the body also includes a list of
|
||||
globals and locals::
|
||||
|
||||
[ Locals ]
|
||||
exampleLocalVar : 'xyz'
|
||||
...
|
||||
( Globals )
|
||||
...
|
||||
|
||||
Or when C{captureVars} is disabled::
|
||||
|
||||
[Capture of Locals and Globals disabled (use captureVars=True)]
|
||||
|
||||
When C{cleanFailure} is enabled references to other objects are removed
|
||||
and replaced with strings.
|
||||
|
||||
And finally the footer with the L{Failure}'s value::
|
||||
|
||||
exceptions.ZeroDivisionError: float division
|
||||
*--- End of Failure #20 ---
|
||||
|
||||
@param captureVars: Enables L{Failure.captureVars}.
|
||||
@type captureVars: C{bool}
|
||||
@param cleanFailure: Enables L{Failure.cleanFailure}.
|
||||
@type cleanFailure: C{bool}
|
||||
"""
|
||||
if captureVars:
|
||||
exampleLocalVar = 'xyz'
|
||||
|
||||
f = getDivisionFailure(captureVars=captureVars)
|
||||
out = NativeStringIO()
|
||||
if cleanFailure:
|
||||
f.cleanFailure()
|
||||
f.printDetailedTraceback(out)
|
||||
|
||||
tb = out.getvalue()
|
||||
start = "*--- Failure #%d%s---\n" % (f.count,
|
||||
(f.pickled and ' (pickled) ') or ' ')
|
||||
end = "%s: %s\n*--- End of Failure #%s ---\n" % (reflect.qual(f.type),
|
||||
reflect.safe_str(f.value), f.count)
|
||||
self.assertTracebackFormat(tb, start, end)
|
||||
|
||||
# Variables are printed on lines with 2 leading spaces.
|
||||
linesWithVars = [line for line in tb.splitlines()
|
||||
if line.startswith(' ')]
|
||||
|
||||
if captureVars:
|
||||
self.assertNotEqual([], linesWithVars)
|
||||
if cleanFailure:
|
||||
line = ' exampleLocalVar : "\'xyz\'"'
|
||||
else:
|
||||
line = " exampleLocalVar : 'xyz'"
|
||||
self.assertIn(line, linesWithVars)
|
||||
else:
|
||||
self.assertEqual([], linesWithVars)
|
||||
self.assertIn(' [Capture of Locals and Globals disabled (use '
|
||||
'captureVars=True)]\n', tb)
|
||||
|
||||
|
||||
def assertBriefTraceback(self, captureVars=False):
|
||||
"""
|
||||
Assert that L{printBriefTraceback} produces and prints a brief
|
||||
traceback.
|
||||
|
||||
The brief traceback consists of a header::
|
||||
|
||||
Traceback: <type 'exceptions.ZeroDivisionError'>: float division
|
||||
|
||||
The body with the stacktrace::
|
||||
|
||||
/twisted/trial/_synctest.py:1180:_run
|
||||
/twisted/python/util.py:1076:runWithWarningsSuppressed
|
||||
|
||||
And the footer::
|
||||
|
||||
--- <exception caught here> ---
|
||||
/twisted/test/test_failure.py:39:getDivisionFailure
|
||||
|
||||
@param captureVars: Enables L{Failure.captureVars}.
|
||||
@type captureVars: C{bool}
|
||||
"""
|
||||
if captureVars:
|
||||
exampleLocalVar = 'abcde'
|
||||
|
||||
f = getDivisionFailure()
|
||||
out = NativeStringIO()
|
||||
f.printBriefTraceback(out)
|
||||
tb = out.getvalue()
|
||||
stack = ''
|
||||
for method, filename, lineno, localVars, globalVars in f.frames:
|
||||
stack += '%s:%s:%s\n' % (filename, lineno, method)
|
||||
|
||||
if _PY3:
|
||||
zde = "class 'ZeroDivisionError'"
|
||||
else:
|
||||
zde = "type 'exceptions.ZeroDivisionError'"
|
||||
|
||||
self.assertTracebackFormat(tb,
|
||||
"Traceback: <%s>: " % (zde,),
|
||||
"%s\n%s" % (failure.EXCEPTION_CAUGHT_HERE, stack))
|
||||
|
||||
if captureVars:
|
||||
self.assertEqual(None, re.search('exampleLocalVar.*abcde', tb))
|
||||
|
||||
|
||||
def assertDefaultTraceback(self, captureVars=False):
|
||||
"""
|
||||
Assert that L{printTraceback} produces and prints a default traceback.
|
||||
|
||||
The default traceback consists of a header::
|
||||
|
||||
Traceback (most recent call last):
|
||||
|
||||
The body with traceback::
|
||||
|
||||
File "/twisted/trial/_synctest.py", line 1180, in _run
|
||||
runWithWarningsSuppressed(suppress, method)
|
||||
|
||||
And the footer::
|
||||
|
||||
--- <exception caught here> ---
|
||||
File "twisted/test/test_failure.py", line 39, in getDivisionFailure
|
||||
1/0
|
||||
exceptions.ZeroDivisionError: float division
|
||||
|
||||
@param captureVars: Enables L{Failure.captureVars}.
|
||||
@type captureVars: C{bool}
|
||||
"""
|
||||
if captureVars:
|
||||
exampleLocalVar = 'xyzzy'
|
||||
|
||||
f = getDivisionFailure(captureVars=captureVars)
|
||||
out = NativeStringIO()
|
||||
f.printTraceback(out)
|
||||
tb = out.getvalue()
|
||||
stack = ''
|
||||
for method, filename, lineno, localVars, globalVars in f.frames:
|
||||
stack += ' File "%s", line %s, in %s\n' % (filename, lineno,
|
||||
method)
|
||||
stack += ' %s\n' % (linecache.getline(
|
||||
filename, lineno).strip(),)
|
||||
|
||||
self.assertTracebackFormat(tb,
|
||||
"Traceback (most recent call last):",
|
||||
"%s\n%s%s: %s\n" % (failure.EXCEPTION_CAUGHT_HERE, stack,
|
||||
reflect.qual(f.type), reflect.safe_str(f.value)))
|
||||
|
||||
if captureVars:
|
||||
self.assertEqual(None, re.search('exampleLocalVar.*xyzzy', tb))
|
||||
|
||||
|
||||
def test_printDetailedTraceback(self):
|
||||
"""
|
||||
L{printDetailedTraceback} returns a detailed traceback including the
|
||||
L{Failure}'s count.
|
||||
"""
|
||||
self.assertDetailedTraceback()
|
||||
|
||||
|
||||
def test_printBriefTraceback(self):
|
||||
"""
|
||||
L{printBriefTraceback} returns a brief traceback.
|
||||
"""
|
||||
self.assertBriefTraceback()
|
||||
|
||||
|
||||
def test_printTraceback(self):
|
||||
"""
|
||||
L{printTraceback} returns a traceback.
|
||||
"""
|
||||
self.assertDefaultTraceback()
|
||||
|
||||
|
||||
def test_printDetailedTracebackCapturedVars(self):
|
||||
"""
|
||||
L{printDetailedTraceback} captures the locals and globals for its
|
||||
stack frames and adds them to the traceback, when called on a
|
||||
L{Failure} constructed with C{captureVars=True}.
|
||||
"""
|
||||
self.assertDetailedTraceback(captureVars=True)
|
||||
|
||||
|
||||
def test_printBriefTracebackCapturedVars(self):
|
||||
"""
|
||||
L{printBriefTraceback} returns a brief traceback when called on a
|
||||
L{Failure} constructed with C{captureVars=True}.
|
||||
|
||||
Local variables on the stack can not be seen in the resulting
|
||||
traceback.
|
||||
"""
|
||||
self.assertBriefTraceback(captureVars=True)
|
||||
|
||||
|
||||
def test_printTracebackCapturedVars(self):
|
||||
"""
|
||||
L{printTraceback} returns a traceback when called on a L{Failure}
|
||||
constructed with C{captureVars=True}.
|
||||
|
||||
Local variables on the stack can not be seen in the resulting
|
||||
traceback.
|
||||
"""
|
||||
self.assertDefaultTraceback(captureVars=True)
|
||||
|
||||
|
||||
def test_printDetailedTracebackCapturedVarsCleaned(self):
|
||||
"""
|
||||
C{printDetailedTraceback} includes information about local variables on
|
||||
the stack after C{cleanFailure} has been called.
|
||||
"""
|
||||
self.assertDetailedTraceback(captureVars=True, cleanFailure=True)
|
||||
|
||||
|
||||
def test_invalidFormatFramesDetail(self):
|
||||
"""
|
||||
L{failure.format_frames} raises a L{ValueError} if the supplied
|
||||
C{detail} level is unknown.
|
||||
"""
|
||||
self.assertRaises(ValueError, failure.format_frames, None, None,
|
||||
detail='noisia')
|
||||
|
||||
|
||||
def testExplictPass(self):
|
||||
e = RuntimeError()
|
||||
f = failure.Failure(e)
|
||||
f.trap(RuntimeError)
|
||||
self.assertEqual(f.value, e)
|
||||
|
||||
|
||||
def _getInnermostFrameLine(self, f):
|
||||
try:
|
||||
f.raiseException()
|
||||
except ZeroDivisionError:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
return tb[-1][-1]
|
||||
else:
|
||||
raise Exception(
|
||||
"f.raiseException() didn't raise ZeroDivisionError!?")
|
||||
|
||||
|
||||
def testRaiseExceptionWithTB(self):
|
||||
f = getDivisionFailure()
|
||||
innerline = self._getInnermostFrameLine(f)
|
||||
self.assertEqual(innerline, '1/0')
|
||||
|
||||
|
||||
def testLackOfTB(self):
|
||||
f = getDivisionFailure()
|
||||
f.cleanFailure()
|
||||
innerline = self._getInnermostFrameLine(f)
|
||||
self.assertEqual(innerline, '1/0')
|
||||
|
||||
testLackOfTB.todo = "the traceback is not preserved, exarkun said he'll try to fix this! god knows how"
|
||||
if _PY3:
|
||||
del testLackOfTB # fix in ticket #6008
|
||||
|
||||
|
||||
def test_stringExceptionConstruction(self):
|
||||
"""
|
||||
Constructing a C{Failure} with a string as its exception value raises
|
||||
a C{TypeError}, as this is no longer supported as of Python 2.6.
|
||||
"""
|
||||
exc = self.assertRaises(TypeError, failure.Failure, "ono!")
|
||||
self.assertIn("Strings are not supported by Failure", str(exc))
|
||||
|
||||
|
||||
def testConstructionFails(self):
|
||||
"""
|
||||
Creating a Failure with no arguments causes it to try to discover the
|
||||
current interpreter exception state. If no such state exists, creating
|
||||
the Failure should raise a synchronous exception.
|
||||
"""
|
||||
self.assertRaises(failure.NoCurrentExceptionError, failure.Failure)
|
||||
|
||||
|
||||
def test_getTracebackObject(self):
|
||||
"""
|
||||
If the C{Failure} has not been cleaned, then C{getTracebackObject}
|
||||
returns the traceback object that captured in its constructor.
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
self.assertEqual(f.getTracebackObject(), f.tb)
|
||||
|
||||
|
||||
def test_getTracebackObjectFromCaptureVars(self):
|
||||
"""
|
||||
C{captureVars=True} has no effect on the result of
|
||||
C{getTracebackObject}.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except ZeroDivisionError:
|
||||
noVarsFailure = failure.Failure()
|
||||
varsFailure = failure.Failure(captureVars=True)
|
||||
self.assertEqual(noVarsFailure.getTracebackObject(), varsFailure.tb)
|
||||
|
||||
|
||||
def test_getTracebackObjectFromClean(self):
|
||||
"""
|
||||
If the Failure has been cleaned, then C{getTracebackObject} returns an
|
||||
object that looks the same to L{traceback.extract_tb}.
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
expected = traceback.extract_tb(f.getTracebackObject())
|
||||
f.cleanFailure()
|
||||
observed = traceback.extract_tb(f.getTracebackObject())
|
||||
self.assertNotEqual(None, expected)
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
|
||||
def test_getTracebackObjectFromCaptureVarsAndClean(self):
|
||||
"""
|
||||
If the Failure was created with captureVars, then C{getTracebackObject}
|
||||
returns an object that looks the same to L{traceback.extract_tb}.
|
||||
"""
|
||||
f = getDivisionFailure(captureVars=True)
|
||||
expected = traceback.extract_tb(f.getTracebackObject())
|
||||
f.cleanFailure()
|
||||
observed = traceback.extract_tb(f.getTracebackObject())
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
|
||||
def test_getTracebackObjectWithoutTraceback(self):
|
||||
"""
|
||||
L{failure.Failure}s need not be constructed with traceback objects. If
|
||||
a C{Failure} has no traceback information at all, C{getTracebackObject}
|
||||
just returns None.
|
||||
|
||||
None is a good value, because traceback.extract_tb(None) -> [].
|
||||
"""
|
||||
f = failure.Failure(Exception("some error"))
|
||||
self.assertEqual(f.getTracebackObject(), None)
|
||||
|
||||
|
||||
def test_tracebackFromExceptionInPython3(self):
|
||||
"""
|
||||
If a L{failure.Failure} is constructed with an exception but no
|
||||
traceback in Python 3, the traceback will be extracted from the
|
||||
exception's C{__traceback__} attribute.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
klass, exception, tb = sys.exc_info()
|
||||
f = failure.Failure(exception)
|
||||
self.assertIdentical(f.tb, tb)
|
||||
|
||||
|
||||
def test_cleanFailureRemovesTracebackInPython3(self):
|
||||
"""
|
||||
L{failure.Failure.cleanFailure} sets the C{__traceback__} attribute of
|
||||
the exception to C{None} in Python 3.
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
self.assertNotEqual(f.tb, None)
|
||||
self.assertIdentical(f.value.__traceback__, f.tb)
|
||||
f.cleanFailure()
|
||||
self.assertIdentical(f.value.__traceback__, None)
|
||||
|
||||
if not _PY3:
|
||||
test_tracebackFromExceptionInPython3.skip = "Python 3 only."
|
||||
test_cleanFailureRemovesTracebackInPython3.skip = "Python 3 only."
|
||||
|
||||
|
||||
|
||||
class BrokenStr(Exception):
|
||||
"""
|
||||
An exception class the instances of which cannot be presented as strings via
|
||||
C{str}.
|
||||
"""
|
||||
def __str__(self):
|
||||
# Could raise something else, but there's no point as yet.
|
||||
raise self
|
||||
|
||||
|
||||
|
||||
class BrokenExceptionMetaclass(type):
|
||||
"""
|
||||
A metaclass for an exception type which cannot be presented as a string via
|
||||
C{str}.
|
||||
"""
|
||||
def __str__(self):
|
||||
raise ValueError("You cannot make a string out of me.")
|
||||
|
||||
|
||||
|
||||
class BrokenExceptionType(Exception, object):
|
||||
"""
|
||||
The aforementioned exception type which cnanot be presented as a string via
|
||||
C{str}.
|
||||
"""
|
||||
__metaclass__ = BrokenExceptionMetaclass
|
||||
|
||||
|
||||
|
||||
class GetTracebackTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{Failure.getTraceback}.
|
||||
"""
|
||||
def _brokenValueTest(self, detail):
|
||||
"""
|
||||
Construct a L{Failure} with an exception that raises an exception from
|
||||
its C{__str__} method and then call C{getTraceback} with the specified
|
||||
detail and verify that it returns a string.
|
||||
"""
|
||||
x = BrokenStr()
|
||||
f = failure.Failure(x)
|
||||
traceback = f.getTraceback(detail=detail)
|
||||
self.assertIsInstance(traceback, str)
|
||||
|
||||
|
||||
def test_brokenValueBriefDetail(self):
|
||||
"""
|
||||
A L{Failure} might wrap an exception with a C{__str__} method which
|
||||
raises an exception. In this case, calling C{getTraceback} on the
|
||||
failure with the C{"brief"} detail does not raise an exception.
|
||||
"""
|
||||
self._brokenValueTest("brief")
|
||||
|
||||
|
||||
def test_brokenValueDefaultDetail(self):
|
||||
"""
|
||||
Like test_brokenValueBriefDetail, but for the C{"default"} detail case.
|
||||
"""
|
||||
self._brokenValueTest("default")
|
||||
|
||||
|
||||
def test_brokenValueVerboseDetail(self):
|
||||
"""
|
||||
Like test_brokenValueBriefDetail, but for the C{"default"} detail case.
|
||||
"""
|
||||
self._brokenValueTest("verbose")
|
||||
|
||||
|
||||
def _brokenTypeTest(self, detail):
|
||||
"""
|
||||
Construct a L{Failure} with an exception type that raises an exception
|
||||
from its C{__str__} method and then call C{getTraceback} with the
|
||||
specified detail and verify that it returns a string.
|
||||
"""
|
||||
f = failure.Failure(BrokenExceptionType())
|
||||
traceback = f.getTraceback(detail=detail)
|
||||
self.assertIsInstance(traceback, str)
|
||||
|
||||
|
||||
def test_brokenTypeBriefDetail(self):
|
||||
"""
|
||||
A L{Failure} might wrap an exception the type object of which has a
|
||||
C{__str__} method which raises an exception. In this case, calling
|
||||
C{getTraceback} on the failure with the C{"brief"} detail does not raise
|
||||
an exception.
|
||||
"""
|
||||
self._brokenTypeTest("brief")
|
||||
|
||||
|
||||
def test_brokenTypeDefaultDetail(self):
|
||||
"""
|
||||
Like test_brokenTypeBriefDetail, but for the C{"default"} detail case.
|
||||
"""
|
||||
self._brokenTypeTest("default")
|
||||
|
||||
|
||||
def test_brokenTypeVerboseDetail(self):
|
||||
"""
|
||||
Like test_brokenTypeBriefDetail, but for the C{"verbose"} detail case.
|
||||
"""
|
||||
self._brokenTypeTest("verbose")
|
||||
|
||||
|
||||
|
||||
class FindFailureTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for functionality related to L{Failure._findFailure}.
|
||||
"""
|
||||
|
||||
def test_findNoFailureInExceptionHandler(self):
|
||||
"""
|
||||
Within an exception handler, _findFailure should return
|
||||
C{None} in case no Failure is associated with the current
|
||||
exception.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
self.assertEqual(failure.Failure._findFailure(), None)
|
||||
else:
|
||||
self.fail("No exception raised from 1/0!?")
|
||||
|
||||
|
||||
def test_findNoFailure(self):
|
||||
"""
|
||||
Outside of an exception handler, _findFailure should return None.
|
||||
"""
|
||||
self.assertEqual(sys.exc_info()[-1], None) #environment sanity check
|
||||
self.assertEqual(failure.Failure._findFailure(), None)
|
||||
|
||||
|
||||
def test_findFailure(self):
|
||||
"""
|
||||
Within an exception handler, it should be possible to find the
|
||||
original Failure that caused the current exception (if it was
|
||||
caused by raiseException).
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
f.cleanFailure()
|
||||
try:
|
||||
f.raiseException()
|
||||
except:
|
||||
self.assertEqual(failure.Failure._findFailure(), f)
|
||||
else:
|
||||
self.fail("No exception raised from raiseException!?")
|
||||
|
||||
|
||||
def test_failureConstructionFindsOriginalFailure(self):
|
||||
"""
|
||||
When a Failure is constructed in the context of an exception
|
||||
handler that is handling an exception raised by
|
||||
raiseException, the new Failure should be chained to that
|
||||
original Failure.
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
f.cleanFailure()
|
||||
try:
|
||||
f.raiseException()
|
||||
except:
|
||||
newF = failure.Failure()
|
||||
self.assertEqual(f.getTraceback(), newF.getTraceback())
|
||||
else:
|
||||
self.fail("No exception raised from raiseException!?")
|
||||
|
||||
|
||||
def test_failureConstructionWithMungedStackSucceeds(self):
|
||||
"""
|
||||
Pyrex and Cython are known to insert fake stack frames so as to give
|
||||
more Python-like tracebacks. These stack frames with empty code objects
|
||||
should not break extraction of the exception.
|
||||
"""
|
||||
try:
|
||||
raiser.raiseException()
|
||||
except raiser.RaiserException:
|
||||
f = failure.Failure()
|
||||
self.assertTrue(f.check(raiser.RaiserException))
|
||||
else:
|
||||
self.fail("No exception raised from extension?!")
|
||||
|
||||
|
||||
if raiser is None:
|
||||
skipMsg = "raiser extension not available"
|
||||
test_failureConstructionWithMungedStackSucceeds.skip = skipMsg
|
||||
|
||||
|
||||
|
||||
class TestFormattableTraceback(SynchronousTestCase):
|
||||
"""
|
||||
Whitebox tests that show that L{failure._Traceback} constructs objects that
|
||||
can be used by L{traceback.extract_tb}.
|
||||
|
||||
If the objects can be used by L{traceback.extract_tb}, then they can be
|
||||
formatted using L{traceback.format_tb} and friends.
|
||||
"""
|
||||
|
||||
def test_singleFrame(self):
|
||||
"""
|
||||
A C{_Traceback} object constructed with a single frame should be able
|
||||
to be passed to L{traceback.extract_tb}, and we should get a singleton
|
||||
list containing a (filename, lineno, methodname, line) tuple.
|
||||
"""
|
||||
tb = failure._Traceback([['method', 'filename.py', 123, {}, {}]])
|
||||
# Note that we don't need to test that extract_tb correctly extracts
|
||||
# the line's contents. In this case, since filename.py doesn't exist,
|
||||
# it will just use None.
|
||||
self.assertEqual(traceback.extract_tb(tb),
|
||||
[('filename.py', 123, 'method', None)])
|
||||
|
||||
|
||||
def test_manyFrames(self):
|
||||
"""
|
||||
A C{_Traceback} object constructed with multiple frames should be able
|
||||
to be passed to L{traceback.extract_tb}, and we should get a list
|
||||
containing a tuple for each frame.
|
||||
"""
|
||||
tb = failure._Traceback([
|
||||
['method1', 'filename.py', 123, {}, {}],
|
||||
['method2', 'filename.py', 235, {}, {}]])
|
||||
self.assertEqual(traceback.extract_tb(tb),
|
||||
[('filename.py', 123, 'method1', None),
|
||||
('filename.py', 235, 'method2', None)])
|
||||
|
||||
|
||||
|
||||
class TestFrameAttributes(SynchronousTestCase):
|
||||
"""
|
||||
_Frame objects should possess some basic attributes that qualify them as
|
||||
fake python Frame objects.
|
||||
"""
|
||||
|
||||
def test_fakeFrameAttributes(self):
|
||||
"""
|
||||
L{_Frame} instances have the C{f_globals} and C{f_locals} attributes
|
||||
bound to C{dict} instance. They also have the C{f_code} attribute
|
||||
bound to something like a code object.
|
||||
"""
|
||||
frame = failure._Frame("dummyname", "dummyfilename")
|
||||
self.assertIsInstance(frame.f_globals, dict)
|
||||
self.assertIsInstance(frame.f_locals, dict)
|
||||
self.assertIsInstance(frame.f_code, failure._Code)
|
||||
|
||||
|
||||
|
||||
class TestDebugMode(SynchronousTestCase):
|
||||
"""
|
||||
Failure's debug mode should allow jumping into the debugger.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Override pdb.post_mortem so we can make sure it's called.
|
||||
"""
|
||||
# Make sure any changes we make are reversed:
|
||||
post_mortem = pdb.post_mortem
|
||||
if _PY3:
|
||||
origInit = failure.Failure.__init__
|
||||
else:
|
||||
origInit = failure.Failure.__dict__['__init__']
|
||||
def restore():
|
||||
pdb.post_mortem = post_mortem
|
||||
if _PY3:
|
||||
failure.Failure.__init__ = origInit
|
||||
else:
|
||||
failure.Failure.__dict__['__init__'] = origInit
|
||||
self.addCleanup(restore)
|
||||
|
||||
self.result = []
|
||||
pdb.post_mortem = self.result.append
|
||||
failure.startDebugMode()
|
||||
|
||||
|
||||
def test_regularFailure(self):
|
||||
"""
|
||||
If startDebugMode() is called, calling Failure() will first call
|
||||
pdb.post_mortem with the traceback.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
typ, exc, tb = sys.exc_info()
|
||||
f = failure.Failure()
|
||||
self.assertEqual(self.result, [tb])
|
||||
self.assertEqual(f.captureVars, False)
|
||||
|
||||
|
||||
def test_captureVars(self):
|
||||
"""
|
||||
If startDebugMode() is called, passing captureVars to Failure() will
|
||||
not blow up.
|
||||
"""
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
typ, exc, tb = sys.exc_info()
|
||||
f = failure.Failure(captureVars=True)
|
||||
self.assertEqual(self.result, [tb])
|
||||
self.assertEqual(f.captureVars, True)
|
||||
|
||||
|
||||
|
||||
class ExtendedGeneratorTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests C{failure.Failure} support for generator features added in Python 2.5
|
||||
"""
|
||||
|
||||
def _throwIntoGenerator(self, f, g):
|
||||
try:
|
||||
f.throwExceptionIntoGenerator(g)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
self.fail("throwExceptionIntoGenerator should have raised "
|
||||
"StopIteration")
|
||||
|
||||
def test_throwExceptionIntoGenerator(self):
|
||||
"""
|
||||
It should be possible to throw the exception that a Failure
|
||||
represents into a generator.
|
||||
"""
|
||||
stuff = []
|
||||
def generator():
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
stuff.append(sys.exc_info())
|
||||
else:
|
||||
self.fail("Yield should have yielded exception.")
|
||||
g = generator()
|
||||
f = getDivisionFailure()
|
||||
next(g)
|
||||
self._throwIntoGenerator(f, g)
|
||||
|
||||
self.assertEqual(stuff[0][0], ZeroDivisionError)
|
||||
self.assertTrue(isinstance(stuff[0][1], ZeroDivisionError))
|
||||
|
||||
self.assertEqual(traceback.extract_tb(stuff[0][2])[-1][-1], "1/0")
|
||||
|
||||
|
||||
def test_findFailureInGenerator(self):
|
||||
"""
|
||||
Within an exception handler, it should be possible to find the
|
||||
original Failure that caused the current exception (if it was
|
||||
caused by throwExceptionIntoGenerator).
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
f.cleanFailure()
|
||||
|
||||
foundFailures = []
|
||||
def generator():
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
foundFailures.append(failure.Failure._findFailure())
|
||||
else:
|
||||
self.fail("No exception sent to generator")
|
||||
|
||||
g = generator()
|
||||
next(g)
|
||||
self._throwIntoGenerator(f, g)
|
||||
|
||||
self.assertEqual(foundFailures, [f])
|
||||
|
||||
|
||||
def test_failureConstructionFindsOriginalFailure(self):
|
||||
"""
|
||||
When a Failure is constructed in the context of an exception
|
||||
handler that is handling an exception raised by
|
||||
throwExceptionIntoGenerator, the new Failure should be chained to that
|
||||
original Failure.
|
||||
"""
|
||||
f = getDivisionFailure()
|
||||
f.cleanFailure()
|
||||
|
||||
newFailures = []
|
||||
|
||||
def generator():
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
newFailures.append(failure.Failure())
|
||||
else:
|
||||
self.fail("No exception sent to generator")
|
||||
g = generator()
|
||||
next(g)
|
||||
self._throwIntoGenerator(f, g)
|
||||
|
||||
self.assertEqual(len(newFailures), 1)
|
||||
self.assertEqual(newFailures[0].getTraceback(), f.getTraceback())
|
||||
|
||||
if _PY3:
|
||||
test_findFailureInGenerator.todo = (
|
||||
"Python 3 support to be fixed in #5949")
|
||||
test_failureConstructionFindsOriginalFailure.todo = (
|
||||
"Python 3 support to be fixed in #5949")
|
||||
# Remove these two lines in #6008 (unittest todo support):
|
||||
del test_findFailureInGenerator
|
||||
del test_failureConstructionFindsOriginalFailure
|
||||
|
||||
|
||||
def test_ambiguousFailureInGenerator(self):
|
||||
"""
|
||||
When a generator reraises a different exception,
|
||||
L{Failure._findFailure} inside the generator should find the reraised
|
||||
exception rather than original one.
|
||||
"""
|
||||
def generator():
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
[][1]
|
||||
except:
|
||||
self.assertIsInstance(failure.Failure().value, IndexError)
|
||||
g = generator()
|
||||
next(g)
|
||||
f = getDivisionFailure()
|
||||
self._throwIntoGenerator(f, g)
|
||||
|
||||
|
||||
def test_ambiguousFailureFromGenerator(self):
|
||||
"""
|
||||
When a generator reraises a different exception,
|
||||
L{Failure._findFailure} above the generator should find the reraised
|
||||
exception rather than original one.
|
||||
"""
|
||||
def generator():
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
[][1]
|
||||
g = generator()
|
||||
next(g)
|
||||
f = getDivisionFailure()
|
||||
try:
|
||||
self._throwIntoGenerator(f, g)
|
||||
except:
|
||||
self.assertIsInstance(failure.Failure().value, IndexError)
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.fdesc}.
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
import errno
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
skip = "not supported on this platform"
|
||||
else:
|
||||
from twisted.internet import fdesc
|
||||
|
||||
from twisted.python.util import untilConcludes
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
|
||||
class NonBlockingTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc.setNonBlocking} and L{fdesc.setBlocking}.
|
||||
"""
|
||||
|
||||
def test_setNonBlocking(self):
|
||||
"""
|
||||
L{fdesc.setNonBlocking} sets a file description to non-blocking.
|
||||
"""
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
fdesc.setNonBlocking(r)
|
||||
self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
|
||||
|
||||
def test_setBlocking(self):
|
||||
"""
|
||||
L{fdesc.setBlocking} sets a file description to blocking.
|
||||
"""
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
fdesc.setNonBlocking(r)
|
||||
fdesc.setBlocking(r)
|
||||
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
|
||||
|
||||
|
||||
class ReadWriteTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc.readFromFD}, L{fdesc.writeToFD}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a non-blocking pipe that can be used in tests.
|
||||
"""
|
||||
self.r, self.w = os.pipe()
|
||||
fdesc.setNonBlocking(self.r)
|
||||
fdesc.setNonBlocking(self.w)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Close pipes.
|
||||
"""
|
||||
try:
|
||||
os.close(self.w)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.close(self.r)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def write(self, d):
|
||||
"""
|
||||
Write data to the pipe.
|
||||
"""
|
||||
return fdesc.writeToFD(self.w, d)
|
||||
|
||||
|
||||
def read(self):
|
||||
"""
|
||||
Read data from the pipe.
|
||||
"""
|
||||
l = []
|
||||
res = fdesc.readFromFD(self.r, l.append)
|
||||
if res is None:
|
||||
if l:
|
||||
return l[0]
|
||||
else:
|
||||
return b""
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def test_writeAndRead(self):
|
||||
"""
|
||||
Test that the number of bytes L{fdesc.writeToFD} reports as written
|
||||
with its return value are seen by L{fdesc.readFromFD}.
|
||||
"""
|
||||
n = self.write(b"hello")
|
||||
self.failUnless(n > 0)
|
||||
s = self.read()
|
||||
self.assertEqual(len(s), n)
|
||||
self.assertEqual(b"hello"[:n], s)
|
||||
|
||||
|
||||
def test_writeAndReadLarge(self):
|
||||
"""
|
||||
Similar to L{test_writeAndRead}, but use a much larger string to verify
|
||||
the behavior for that case.
|
||||
"""
|
||||
orig = b"0123456879" * 10000
|
||||
written = self.write(orig)
|
||||
self.failUnless(written > 0)
|
||||
result = []
|
||||
resultlength = 0
|
||||
i = 0
|
||||
while resultlength < written or i < 50:
|
||||
result.append(self.read())
|
||||
resultlength += len(result[-1])
|
||||
# Increment a counter to be sure we'll exit at some point
|
||||
i += 1
|
||||
result = b"".join(result)
|
||||
self.assertEqual(len(result), written)
|
||||
self.assertEqual(orig[:written], result)
|
||||
|
||||
|
||||
def test_readFromEmpty(self):
|
||||
"""
|
||||
Verify that reading from a file descriptor with no data does not raise
|
||||
an exception and does not result in the callback function being called.
|
||||
"""
|
||||
l = []
|
||||
result = fdesc.readFromFD(self.r, l.append)
|
||||
self.assertEqual(l, [])
|
||||
self.assertEqual(result, None)
|
||||
|
||||
|
||||
def test_readFromCleanClose(self):
|
||||
"""
|
||||
Test that using L{fdesc.readFromFD} on a cleanly closed file descriptor
|
||||
returns a connection done indicator.
|
||||
"""
|
||||
os.close(self.w)
|
||||
self.assertEqual(self.read(), fdesc.CONNECTION_DONE)
|
||||
|
||||
|
||||
def test_writeToClosed(self):
|
||||
"""
|
||||
Verify that writing with L{fdesc.writeToFD} when the read end is closed
|
||||
results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.r)
|
||||
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
|
||||
|
||||
|
||||
def test_readFromInvalid(self):
|
||||
"""
|
||||
Verify that reading with L{fdesc.readFromFD} when the read end is
|
||||
closed results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.r)
|
||||
self.assertEqual(self.read(), fdesc.CONNECTION_LOST)
|
||||
|
||||
|
||||
def test_writeToInvalid(self):
|
||||
"""
|
||||
Verify that writing with L{fdesc.writeToFD} when the write end is
|
||||
closed results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.w)
|
||||
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
|
||||
|
||||
|
||||
def test_writeErrors(self):
|
||||
"""
|
||||
Test error path for L{fdesc.writeTod}.
|
||||
"""
|
||||
oldOsWrite = os.write
|
||||
def eagainWrite(fd, data):
|
||||
err = OSError()
|
||||
err.errno = errno.EAGAIN
|
||||
raise err
|
||||
os.write = eagainWrite
|
||||
try:
|
||||
self.assertEqual(self.write(b"s"), 0)
|
||||
finally:
|
||||
os.write = oldOsWrite
|
||||
|
||||
def eintrWrite(fd, data):
|
||||
err = OSError()
|
||||
err.errno = errno.EINTR
|
||||
raise err
|
||||
os.write = eintrWrite
|
||||
try:
|
||||
self.assertEqual(self.write(b"s"), 0)
|
||||
finally:
|
||||
os.write = oldOsWrite
|
||||
|
||||
|
||||
|
||||
class CloseOnExecTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc._setCloseOnExec} and L{fdesc._unsetCloseOnExec}.
|
||||
"""
|
||||
program = '''
|
||||
import os, errno
|
||||
try:
|
||||
os.write(%d, b'lul')
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF:
|
||||
os._exit(0)
|
||||
os._exit(5)
|
||||
except:
|
||||
os._exit(10)
|
||||
else:
|
||||
os._exit(20)
|
||||
'''
|
||||
|
||||
def _execWithFileDescriptor(self, fObj):
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
try:
|
||||
os.execv(sys.executable, [sys.executable, '-c', self.program % (fObj.fileno(),)])
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
os._exit(30)
|
||||
else:
|
||||
# On Linux wait(2) doesn't seem ever able to fail with EINTR but
|
||||
# POSIX seems to allow it and on OS X it happens quite a lot.
|
||||
return untilConcludes(os.waitpid, pid, 0)[1]
|
||||
|
||||
|
||||
def test_setCloseOnExec(self):
|
||||
"""
|
||||
A file descriptor passed to L{fdesc._setCloseOnExec} is not inherited
|
||||
by a new process image created with one of the exec family of
|
||||
functions.
|
||||
"""
|
||||
with open(self.mktemp(), 'wb') as fObj:
|
||||
fdesc._setCloseOnExec(fObj.fileno())
|
||||
status = self._execWithFileDescriptor(fObj)
|
||||
self.assertTrue(os.WIFEXITED(status))
|
||||
self.assertEqual(os.WEXITSTATUS(status), 0)
|
||||
|
||||
|
||||
def test_unsetCloseOnExec(self):
|
||||
"""
|
||||
A file descriptor passed to L{fdesc._unsetCloseOnExec} is inherited by
|
||||
a new process image created with one of the exec family of functions.
|
||||
"""
|
||||
with open(self.mktemp(), 'wb') as fObj:
|
||||
fdesc._setCloseOnExec(fObj.fileno())
|
||||
fdesc._unsetCloseOnExec(fObj.fileno())
|
||||
status = self._execWithFileDescriptor(fObj)
|
||||
self.assertTrue(os.WIFEXITED(status))
|
||||
self.assertEqual(os.WEXITSTATUS(status), 20)
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocols.finger}.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import finger
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
class FingerTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{finger.Finger}.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Create and connect a L{finger.Finger} instance.
|
||||
"""
|
||||
self.transport = StringTransport()
|
||||
self.protocol = finger.Finger()
|
||||
self.protocol.makeConnection(self.transport)
|
||||
|
||||
|
||||
def test_simple(self):
|
||||
"""
|
||||
When L{finger.Finger} receives a CR LF terminated line, it responds
|
||||
with the default user status message - that no such user exists.
|
||||
"""
|
||||
self.protocol.dataReceived("moshez\r\n")
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
"Login: moshez\nNo such user\n")
|
||||
|
||||
|
||||
def test_simpleW(self):
|
||||
"""
|
||||
The behavior for a query which begins with C{"/w"} is the same as the
|
||||
behavior for one which does not. The user is reported as not existing.
|
||||
"""
|
||||
self.protocol.dataReceived("/w moshez\r\n")
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
"Login: moshez\nNo such user\n")
|
||||
|
||||
|
||||
def test_forwarding(self):
|
||||
"""
|
||||
When L{finger.Finger} receives a request for a remote user, it responds
|
||||
with a message rejecting the request.
|
||||
"""
|
||||
self.protocol.dataReceived("moshez@example.com\r\n")
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
"Finger forwarding service denied\n")
|
||||
|
||||
|
||||
def test_list(self):
|
||||
"""
|
||||
When L{finger.Finger} receives a blank line, it responds with a message
|
||||
rejecting the request for all online users.
|
||||
"""
|
||||
self.protocol.dataReceived("\r\n")
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
"Finger online list denied\n")
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for formmethod module.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.python import formmethod
|
||||
|
||||
|
||||
class ArgumentTestCase(unittest.TestCase):
|
||||
|
||||
def argTest(self, argKlass, testPairs, badValues, *args, **kwargs):
|
||||
arg = argKlass("name", *args, **kwargs)
|
||||
for val, result in testPairs:
|
||||
self.assertEqual(arg.coerce(val), result)
|
||||
for val in badValues:
|
||||
self.assertRaises(formmethod.InputError, arg.coerce, val)
|
||||
|
||||
|
||||
def test_argument(self):
|
||||
"""
|
||||
Test that corce correctly raises NotImplementedError.
|
||||
"""
|
||||
arg = formmethod.Argument("name")
|
||||
self.assertRaises(NotImplementedError, arg.coerce, "")
|
||||
|
||||
|
||||
def testString(self):
|
||||
self.argTest(formmethod.String, [("a", "a"), (1, "1"), ("", "")], ())
|
||||
self.argTest(formmethod.String, [("ab", "ab"), ("abc", "abc")], ("2", ""), min=2)
|
||||
self.argTest(formmethod.String, [("ab", "ab"), ("a", "a")], ("223213", "345x"), max=3)
|
||||
self.argTest(formmethod.String, [("ab", "ab"), ("add", "add")], ("223213", "x"), min=2, max=3)
|
||||
|
||||
def testInt(self):
|
||||
self.argTest(formmethod.Integer, [("3", 3), ("-2", -2), ("", None)], ("q", "2.3"))
|
||||
self.argTest(formmethod.Integer, [("3", 3), ("-2", -2)], ("q", "2.3", ""), allowNone=0)
|
||||
|
||||
def testFloat(self):
|
||||
self.argTest(formmethod.Float, [("3", 3.0), ("-2.3", -2.3), ("", None)], ("q", "2.3z"))
|
||||
self.argTest(formmethod.Float, [("3", 3.0), ("-2.3", -2.3)], ("q", "2.3z", ""),
|
||||
allowNone=0)
|
||||
|
||||
def testChoice(self):
|
||||
choices = [("a", "apple", "an apple"),
|
||||
("b", "banana", "ook")]
|
||||
self.argTest(formmethod.Choice, [("a", "apple"), ("b", "banana")],
|
||||
("c", 1), choices=choices)
|
||||
|
||||
def testFlags(self):
|
||||
flags = [("a", "apple", "an apple"),
|
||||
("b", "banana", "ook")]
|
||||
self.argTest(formmethod.Flags,
|
||||
[(["a"], ["apple"]), (["b", "a"], ["banana", "apple"])],
|
||||
(["a", "c"], ["fdfs"]),
|
||||
flags=flags)
|
||||
|
||||
def testBoolean(self):
|
||||
tests = [("yes", 1), ("", 0), ("False", 0), ("no", 0)]
|
||||
self.argTest(formmethod.Boolean, tests, ())
|
||||
|
||||
|
||||
def test_file(self):
|
||||
"""
|
||||
Test the correctness of the coerce function.
|
||||
"""
|
||||
arg = formmethod.File("name", allowNone=0)
|
||||
self.assertEqual(arg.coerce("something"), "something")
|
||||
self.assertRaises(formmethod.InputError, arg.coerce, None)
|
||||
arg2 = formmethod.File("name")
|
||||
self.assertEqual(arg2.coerce(None), None)
|
||||
|
||||
|
||||
def testDate(self):
|
||||
goodTests = {
|
||||
("2002", "12", "21"): (2002, 12, 21),
|
||||
("1996", "2", "29"): (1996, 2, 29),
|
||||
("", "", ""): None,
|
||||
}.items()
|
||||
badTests = [("2002", "2", "29"), ("xx", "2", "3"),
|
||||
("2002", "13", "1"), ("1999", "12","32"),
|
||||
("2002", "1"), ("2002", "2", "3", "4")]
|
||||
self.argTest(formmethod.Date, goodTests, badTests)
|
||||
|
||||
def testRangedInteger(self):
|
||||
goodTests = {"0": 0, "12": 12, "3": 3}.items()
|
||||
badTests = ["-1", "x", "13", "-2000", "3.4"]
|
||||
self.argTest(formmethod.IntegerRange, goodTests, badTests, 0, 12)
|
||||
|
||||
def testVerifiedPassword(self):
|
||||
goodTests = {("foo", "foo"): "foo", ("ab", "ab"): "ab"}.items()
|
||||
badTests = [("ab", "a"), ("12345", "12345"), ("", ""), ("a", "a"), ("a",), ("a", "a", "a")]
|
||||
self.argTest(formmethod.VerifiedPassword, goodTests, badTests, min=2, max=4)
|
||||
|
||||
|
||||
3482
Linux_i686/lib/python2.7/site-packages/twisted/test/test_ftp.py
Normal file
3482
Linux_i686/lib/python2.7/site-packages/twisted/test/test_ftp.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.tap.ftp}.
|
||||
"""
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
from twisted.cred import credentials, error
|
||||
from twisted.tap.ftp import Options
|
||||
from twisted.python import versions
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
|
||||
|
||||
class FTPOptionsTestCase(TestCase):
|
||||
"""
|
||||
Tests for the command line option parser used for C{twistd ftp}.
|
||||
"""
|
||||
|
||||
usernamePassword = ('iamuser', 'thisispassword')
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a file with two users.
|
||||
"""
|
||||
self.filename = self.mktemp()
|
||||
f = FilePath(self.filename)
|
||||
f.setContent(':'.join(self.usernamePassword))
|
||||
self.options = Options()
|
||||
|
||||
|
||||
def test_passwordfileDeprecation(self):
|
||||
"""
|
||||
The C{--password-file} option will emit a warning stating that
|
||||
said option is deprecated.
|
||||
"""
|
||||
self.callDeprecated(
|
||||
versions.Version("Twisted", 11, 1, 0),
|
||||
self.options.opt_password_file, self.filename)
|
||||
|
||||
|
||||
def test_authAdded(self):
|
||||
"""
|
||||
The C{--auth} command-line option will add a checker to the list of
|
||||
checkers
|
||||
"""
|
||||
numCheckers = len(self.options['credCheckers'])
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
self.assertEqual(len(self.options['credCheckers']), numCheckers + 1)
|
||||
|
||||
|
||||
def test_authFailure(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that fails with L{UnauthorizedLogin} when
|
||||
presented with credentials that are unknown to that checker.
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
checker = self.options['credCheckers'][-1]
|
||||
invalid = credentials.UsernamePassword(self.usernamePassword[0], 'fake')
|
||||
return (checker.requestAvatarId(invalid)
|
||||
.addCallbacks(
|
||||
lambda ignore: self.fail("Wrong password should raise error"),
|
||||
lambda err: err.trap(error.UnauthorizedLogin)))
|
||||
|
||||
|
||||
def test_authSuccess(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that returns the avatar id when presented with credentials
|
||||
that are known to that checker.
|
||||
"""
|
||||
self.options.parseOptions(['--auth', 'file:' + self.filename])
|
||||
checker = self.options['credCheckers'][-1]
|
||||
correct = credentials.UsernamePassword(*self.usernamePassword)
|
||||
return checker.requestAvatarId(correct).addCallback(
|
||||
lambda username: self.assertEqual(username, correct.username)
|
||||
)
|
||||
150
Linux_i686/lib/python2.7/site-packages/twisted/test/test_hook.py
Normal file
150
Linux_i686/lib/python2.7/site-packages/twisted/test/test_hook.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for twisted.hook module.
|
||||
"""
|
||||
|
||||
from twisted.python import hook
|
||||
from twisted.trial import unittest
|
||||
|
||||
class BaseClass:
|
||||
"""
|
||||
dummy class to help in testing.
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
dummy initializer
|
||||
"""
|
||||
self.calledBasePre = 0
|
||||
self.calledBasePost = 0
|
||||
self.calledBase = 0
|
||||
|
||||
def func(self, a, b):
|
||||
"""
|
||||
dummy method
|
||||
"""
|
||||
assert a == 1
|
||||
assert b == 2
|
||||
self.calledBase = self.calledBase + 1
|
||||
|
||||
|
||||
class SubClass(BaseClass):
|
||||
"""
|
||||
another dummy class
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
another dummy initializer
|
||||
"""
|
||||
BaseClass.__init__(self)
|
||||
self.calledSubPre = 0
|
||||
self.calledSubPost = 0
|
||||
self.calledSub = 0
|
||||
|
||||
def func(self, a, b):
|
||||
"""
|
||||
another dummy function
|
||||
"""
|
||||
assert a == 1
|
||||
assert b == 2
|
||||
BaseClass.func(self, a, b)
|
||||
self.calledSub = self.calledSub + 1
|
||||
|
||||
_clean_BaseClass = BaseClass.__dict__.copy()
|
||||
_clean_SubClass = SubClass.__dict__.copy()
|
||||
|
||||
def basePre(base, a, b):
|
||||
"""
|
||||
a pre-hook for the base class
|
||||
"""
|
||||
base.calledBasePre = base.calledBasePre + 1
|
||||
|
||||
def basePost(base, a, b):
|
||||
"""
|
||||
a post-hook for the base class
|
||||
"""
|
||||
base.calledBasePost = base.calledBasePost + 1
|
||||
|
||||
def subPre(sub, a, b):
|
||||
"""
|
||||
a pre-hook for the subclass
|
||||
"""
|
||||
sub.calledSubPre = sub.calledSubPre + 1
|
||||
|
||||
def subPost(sub, a, b):
|
||||
"""
|
||||
a post-hook for the subclass
|
||||
"""
|
||||
sub.calledSubPost = sub.calledSubPost + 1
|
||||
|
||||
class HookTestCase(unittest.TestCase):
|
||||
"""
|
||||
test case to make sure hooks are called
|
||||
"""
|
||||
def setUp(self):
|
||||
"""Make sure we have clean versions of our classes."""
|
||||
BaseClass.__dict__.clear()
|
||||
BaseClass.__dict__.update(_clean_BaseClass)
|
||||
SubClass.__dict__.clear()
|
||||
SubClass.__dict__.update(_clean_SubClass)
|
||||
|
||||
def testBaseHook(self):
|
||||
"""make sure that the base class's hook is called reliably
|
||||
"""
|
||||
base = BaseClass()
|
||||
self.assertEqual(base.calledBase, 0)
|
||||
self.assertEqual(base.calledBasePre, 0)
|
||||
base.func(1,2)
|
||||
self.assertEqual(base.calledBase, 1)
|
||||
self.assertEqual(base.calledBasePre, 0)
|
||||
hook.addPre(BaseClass, "func", basePre)
|
||||
base.func(1, b=2)
|
||||
self.assertEqual(base.calledBase, 2)
|
||||
self.assertEqual(base.calledBasePre, 1)
|
||||
hook.addPost(BaseClass, "func", basePost)
|
||||
base.func(1, b=2)
|
||||
self.assertEqual(base.calledBasePost, 1)
|
||||
self.assertEqual(base.calledBase, 3)
|
||||
self.assertEqual(base.calledBasePre, 2)
|
||||
hook.removePre(BaseClass, "func", basePre)
|
||||
hook.removePost(BaseClass, "func", basePost)
|
||||
base.func(1, b=2)
|
||||
self.assertEqual(base.calledBasePost, 1)
|
||||
self.assertEqual(base.calledBase, 4)
|
||||
self.assertEqual(base.calledBasePre, 2)
|
||||
|
||||
def testSubHook(self):
|
||||
"""test interactions between base-class hooks and subclass hooks
|
||||
"""
|
||||
sub = SubClass()
|
||||
self.assertEqual(sub.calledSub, 0)
|
||||
self.assertEqual(sub.calledBase, 0)
|
||||
sub.func(1, b=2)
|
||||
self.assertEqual(sub.calledSub, 1)
|
||||
self.assertEqual(sub.calledBase, 1)
|
||||
hook.addPre(SubClass, 'func', subPre)
|
||||
self.assertEqual(sub.calledSub, 1)
|
||||
self.assertEqual(sub.calledBase, 1)
|
||||
self.assertEqual(sub.calledSubPre, 0)
|
||||
self.assertEqual(sub.calledBasePre, 0)
|
||||
sub.func(1, b=2)
|
||||
self.assertEqual(sub.calledSub, 2)
|
||||
self.assertEqual(sub.calledBase, 2)
|
||||
self.assertEqual(sub.calledSubPre, 1)
|
||||
self.assertEqual(sub.calledBasePre, 0)
|
||||
# let the pain begin
|
||||
hook.addPre(BaseClass, 'func', basePre)
|
||||
BaseClass.func(sub, 1, b=2)
|
||||
# sub.func(1, b=2)
|
||||
self.assertEqual(sub.calledBase, 3)
|
||||
self.assertEqual(sub.calledBasePre, 1, str(sub.calledBasePre))
|
||||
sub.func(1, b=2)
|
||||
self.assertEqual(sub.calledBasePre, 2)
|
||||
self.assertEqual(sub.calledBase, 4)
|
||||
self.assertEqual(sub.calledSubPre, 2)
|
||||
self.assertEqual(sub.calledSub, 3)
|
||||
|
||||
testCases = [HookTestCase]
|
||||
109
Linux_i686/lib/python2.7/site-packages/twisted/test/test_htb.py
Normal file
109
Linux_i686/lib/python2.7/site-packages/twisted/test/test_htb.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
# -*- Python -*-
|
||||
|
||||
__version__ = '$Revision: 1.3 $'[11:-2]
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import htb
|
||||
|
||||
class DummyClock:
|
||||
time = 0
|
||||
def set(self, when):
|
||||
self.time = when
|
||||
|
||||
def __call__(self):
|
||||
return self.time
|
||||
|
||||
class SomeBucket(htb.Bucket):
|
||||
maxburst = 100
|
||||
rate = 2
|
||||
|
||||
class TestBucketBase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._realTimeFunc = htb.time
|
||||
self.clock = DummyClock()
|
||||
htb.time = self.clock
|
||||
|
||||
def tearDown(self):
|
||||
htb.time = self._realTimeFunc
|
||||
|
||||
class TestBucket(TestBucketBase):
|
||||
def testBucketSize(self):
|
||||
"""Testing the size of the bucket."""
|
||||
b = SomeBucket()
|
||||
fit = b.add(1000)
|
||||
self.assertEqual(100, fit)
|
||||
|
||||
def testBucketDrain(self):
|
||||
"""Testing the bucket's drain rate."""
|
||||
b = SomeBucket()
|
||||
fit = b.add(1000)
|
||||
self.clock.set(10)
|
||||
fit = b.add(1000)
|
||||
self.assertEqual(20, fit)
|
||||
|
||||
def test_bucketEmpty(self):
|
||||
"""
|
||||
L{htb.Bucket.drip} returns C{True} if the bucket is empty after that drip.
|
||||
"""
|
||||
b = SomeBucket()
|
||||
b.add(20)
|
||||
self.clock.set(9)
|
||||
empty = b.drip()
|
||||
self.assertFalse(empty)
|
||||
self.clock.set(10)
|
||||
empty = b.drip()
|
||||
self.assertTrue(empty)
|
||||
|
||||
class TestBucketNesting(TestBucketBase):
|
||||
def setUp(self):
|
||||
TestBucketBase.setUp(self)
|
||||
self.parent = SomeBucket()
|
||||
self.child1 = SomeBucket(self.parent)
|
||||
self.child2 = SomeBucket(self.parent)
|
||||
|
||||
def testBucketParentSize(self):
|
||||
# Use up most of the parent bucket.
|
||||
self.child1.add(90)
|
||||
fit = self.child2.add(90)
|
||||
self.assertEqual(10, fit)
|
||||
|
||||
def testBucketParentRate(self):
|
||||
# Make the parent bucket drain slower.
|
||||
self.parent.rate = 1
|
||||
# Fill both child1 and parent.
|
||||
self.child1.add(100)
|
||||
self.clock.set(10)
|
||||
fit = self.child1.add(100)
|
||||
# How much room was there? The child bucket would have had 20,
|
||||
# but the parent bucket only ten (so no, it wouldn't make too much
|
||||
# sense to have a child bucket draining faster than its parent in a real
|
||||
# application.)
|
||||
self.assertEqual(10, fit)
|
||||
|
||||
|
||||
# TODO: Test the Transport stuff?
|
||||
|
||||
from test_pcp import DummyConsumer
|
||||
|
||||
class ConsumerShaperTest(TestBucketBase):
|
||||
def setUp(self):
|
||||
TestBucketBase.setUp(self)
|
||||
self.underlying = DummyConsumer()
|
||||
self.bucket = SomeBucket()
|
||||
self.shaped = htb.ShapedConsumer(self.underlying, self.bucket)
|
||||
|
||||
def testRate(self):
|
||||
# Start off with a full bucket, so the burst-size dosen't factor in
|
||||
# to the calculations.
|
||||
delta_t = 10
|
||||
self.bucket.add(100)
|
||||
self.shaped.write("x" * 100)
|
||||
self.clock.set(delta_t)
|
||||
self.shaped.resumeProducing()
|
||||
self.assertEqual(len(self.underlying.getvalue()),
|
||||
delta_t * self.bucket.rate)
|
||||
|
||||
def testBucketRefs(self):
|
||||
self.assertEqual(self.bucket._refcount, 1)
|
||||
self.shaped.stopProducing()
|
||||
self.assertEqual(self.bucket._refcount, 0)
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.ident module.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.protocols import ident
|
||||
from twisted.python import failure
|
||||
from twisted.internet import error
|
||||
from twisted.internet import defer
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
|
||||
class ClassParserTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test parsing of ident responses.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a ident client used in tests.
|
||||
"""
|
||||
self.client = ident.IdentClient()
|
||||
|
||||
|
||||
def test_indentError(self):
|
||||
"""
|
||||
'UNKNOWN-ERROR' error should map to the L{ident.IdentError} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 123, 456))
|
||||
self.client.lineReceived('123, 456 : ERROR : UNKNOWN-ERROR')
|
||||
return self.assertFailure(d, ident.IdentError)
|
||||
|
||||
|
||||
def test_noUSerError(self):
|
||||
"""
|
||||
'NO-USER' error should map to the L{ident.NoUser} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 234, 456))
|
||||
self.client.lineReceived('234, 456 : ERROR : NO-USER')
|
||||
return self.assertFailure(d, ident.NoUser)
|
||||
|
||||
|
||||
def test_invalidPortError(self):
|
||||
"""
|
||||
'INVALID-PORT' error should map to the L{ident.InvalidPort} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 345, 567))
|
||||
self.client.lineReceived('345, 567 : ERROR : INVALID-PORT')
|
||||
return self.assertFailure(d, ident.InvalidPort)
|
||||
|
||||
|
||||
def test_hiddenUserError(self):
|
||||
"""
|
||||
'HIDDEN-USER' error should map to the L{ident.HiddenUser} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 567, 789))
|
||||
self.client.lineReceived('567, 789 : ERROR : HIDDEN-USER')
|
||||
return self.assertFailure(d, ident.HiddenUser)
|
||||
|
||||
|
||||
def test_lostConnection(self):
|
||||
"""
|
||||
A pending query which failed because of a ConnectionLost should
|
||||
receive an L{ident.IdentError}.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 765, 432))
|
||||
self.client.connectionLost(failure.Failure(error.ConnectionLost()))
|
||||
return self.assertFailure(d, ident.IdentError)
|
||||
|
||||
|
||||
|
||||
class TestIdentServer(ident.IdentServer):
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
return self.resultValue
|
||||
|
||||
|
||||
class TestErrorIdentServer(ident.IdentServer):
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
raise self.exceptionType()
|
||||
|
||||
|
||||
class NewException(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class ServerParserTestCase(unittest.TestCase):
|
||||
def testErrors(self):
|
||||
p = TestErrorIdentServer()
|
||||
p.makeConnection(StringTransport())
|
||||
L = []
|
||||
p.sendLine = L.append
|
||||
|
||||
p.exceptionType = ident.IdentError
|
||||
p.lineReceived('123, 345')
|
||||
self.assertEqual(L[0], '123, 345 : ERROR : UNKNOWN-ERROR')
|
||||
|
||||
p.exceptionType = ident.NoUser
|
||||
p.lineReceived('432, 210')
|
||||
self.assertEqual(L[1], '432, 210 : ERROR : NO-USER')
|
||||
|
||||
p.exceptionType = ident.InvalidPort
|
||||
p.lineReceived('987, 654')
|
||||
self.assertEqual(L[2], '987, 654 : ERROR : INVALID-PORT')
|
||||
|
||||
p.exceptionType = ident.HiddenUser
|
||||
p.lineReceived('756, 827')
|
||||
self.assertEqual(L[3], '756, 827 : ERROR : HIDDEN-USER')
|
||||
|
||||
p.exceptionType = NewException
|
||||
p.lineReceived('987, 789')
|
||||
self.assertEqual(L[4], '987, 789 : ERROR : UNKNOWN-ERROR')
|
||||
errs = self.flushLoggedErrors(NewException)
|
||||
self.assertEqual(len(errs), 1)
|
||||
|
||||
for port in -1, 0, 65536, 65537:
|
||||
del L[:]
|
||||
p.lineReceived('%d, 5' % (port,))
|
||||
p.lineReceived('5, %d' % (port,))
|
||||
self.assertEqual(
|
||||
L, ['%d, 5 : ERROR : INVALID-PORT' % (port,),
|
||||
'5, %d : ERROR : INVALID-PORT' % (port,)])
|
||||
|
||||
def testSuccess(self):
|
||||
p = TestIdentServer()
|
||||
p.makeConnection(StringTransport())
|
||||
L = []
|
||||
p.sendLine = L.append
|
||||
|
||||
p.resultValue = ('SYS', 'USER')
|
||||
p.lineReceived('123, 456')
|
||||
self.assertEqual(L[0], '123, 456 : USERID : SYS : USER')
|
||||
|
||||
|
||||
if struct.pack('=L', 1)[0] == '\x01':
|
||||
_addr1 = '0100007F'
|
||||
_addr2 = '04030201'
|
||||
else:
|
||||
_addr1 = '7F000001'
|
||||
_addr2 = '01020304'
|
||||
|
||||
|
||||
class ProcMixinTestCase(unittest.TestCase):
|
||||
line = ('4: %s:0019 %s:02FA 0A 00000000:00000000 '
|
||||
'00:00000000 00000000 0 0 10927 1 f72a5b80 '
|
||||
'3000 0 0 2 -1') % (_addr1, _addr2)
|
||||
|
||||
def testDottedQuadFromHexString(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(p.dottedQuadFromHexString(_addr1), '127.0.0.1')
|
||||
|
||||
def testUnpackAddress(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(p.unpackAddress(_addr1 + ':0277'),
|
||||
('127.0.0.1', 631))
|
||||
|
||||
def testLineParser(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(
|
||||
p.parseLine(self.line),
|
||||
(('127.0.0.1', 25), ('1.2.3.4', 762), 0))
|
||||
|
||||
def testExistingAddress(self):
|
||||
username = []
|
||||
p = ident.ProcServerMixin()
|
||||
p.entries = lambda: iter([self.line])
|
||||
p.getUsername = lambda uid: (username.append(uid), 'root')[1]
|
||||
self.assertEqual(
|
||||
p.lookup(('127.0.0.1', 25), ('1.2.3.4', 762)),
|
||||
(p.SYSTEM_NAME, 'root'))
|
||||
self.assertEqual(username, [0])
|
||||
|
||||
def testNonExistingAddress(self):
|
||||
p = ident.ProcServerMixin()
|
||||
p.entries = lambda: iter([self.line])
|
||||
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 26),
|
||||
('1.2.3.4', 762))
|
||||
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 25),
|
||||
('1.2.3.5', 762))
|
||||
self.assertRaises(ident.NoUser, p.lookup, ('127.0.0.1', 25),
|
||||
('1.2.3.4', 763))
|
||||
|
||||
1442
Linux_i686/lib/python2.7/site-packages/twisted/test/test_internet.py
Normal file
1442
Linux_i686/lib/python2.7/site-packages/twisted/test/test_internet.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.test.iosim}.
|
||||
"""
|
||||
|
||||
from twisted.test.iosim import FakeTransport
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
class FakeTransportTests(TestCase):
|
||||
"""
|
||||
Tests for L{FakeTransport}
|
||||
"""
|
||||
|
||||
def test_connectionSerial(self):
|
||||
"""
|
||||
Each L{FakeTransport} receives a serial number that uniquely identifies
|
||||
it.
|
||||
"""
|
||||
a = FakeTransport(object(), True)
|
||||
b = FakeTransport(object(), False)
|
||||
self.assertIsInstance(a.serial, int)
|
||||
self.assertIsInstance(b.serial, int)
|
||||
self.assertNotEqual(a.serial, b.serial)
|
||||
|
|
@ -0,0 +1,351 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test running processes with the APIs in L{twisted.internet.utils}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import warnings, os, stat, sys, signal
|
||||
|
||||
from twisted.python.compat import _PY3
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import error, reactor, utils, interfaces
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.test.test_util import SuppressedWarningsTests
|
||||
|
||||
|
||||
class ProcessUtilsTests(unittest.TestCase):
|
||||
"""
|
||||
Test running a process using L{getProcessOutput}, L{getProcessValue}, and
|
||||
L{getProcessOutputAndValue}.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorProcess(reactor, None) is None:
|
||||
skip = "reactor doesn't implement IReactorProcess"
|
||||
|
||||
output = None
|
||||
value = None
|
||||
exe = sys.executable
|
||||
|
||||
def makeSourceFile(self, sourceLines):
|
||||
"""
|
||||
Write the given list of lines to a text file and return the absolute
|
||||
path to it.
|
||||
"""
|
||||
script = self.mktemp()
|
||||
scriptFile = file(script, 'wt')
|
||||
scriptFile.write(os.linesep.join(sourceLines) + os.linesep)
|
||||
scriptFile.close()
|
||||
return os.path.abspath(script)
|
||||
|
||||
|
||||
def test_output(self):
|
||||
"""
|
||||
L{getProcessOutput} returns a L{Deferred} which fires with the complete
|
||||
output of the process it runs after that process exits.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile([
|
||||
"import sys",
|
||||
"for s in 'hello world\\n':",
|
||||
" sys.stdout.write(s)",
|
||||
" sys.stdout.flush()"])
|
||||
d = utils.getProcessOutput(self.exe, ['-u', scriptFile])
|
||||
return d.addCallback(self.assertEqual, "hello world\n")
|
||||
|
||||
|
||||
def test_outputWithErrorIgnored(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessOutput} is fired with an
|
||||
L{IOError} L{Failure} if the child process writes to stderr.
|
||||
"""
|
||||
# make sure stderr raises an error normally
|
||||
scriptFile = self.makeSourceFile([
|
||||
'import sys',
|
||||
'sys.stderr.write("hello world\\n")'
|
||||
])
|
||||
|
||||
d = utils.getProcessOutput(self.exe, ['-u', scriptFile])
|
||||
d = self.assertFailure(d, IOError)
|
||||
def cbFailed(err):
|
||||
return self.assertFailure(err.processEnded, error.ProcessDone)
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
|
||||
def test_outputWithErrorCollected(self):
|
||||
"""
|
||||
If a C{True} value is supplied for the C{errortoo} parameter to
|
||||
L{getProcessOutput}, the returned L{Deferred} fires with the child's
|
||||
stderr output as well as its stdout output.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile([
|
||||
'import sys',
|
||||
# Write the same value to both because ordering isn't guaranteed so
|
||||
# this simplifies the test.
|
||||
'sys.stdout.write("foo")',
|
||||
'sys.stdout.flush()',
|
||||
'sys.stderr.write("foo")',
|
||||
'sys.stderr.flush()'])
|
||||
|
||||
d = utils.getProcessOutput(self.exe, ['-u', scriptFile], errortoo=True)
|
||||
return d.addCallback(self.assertEqual, "foofoo")
|
||||
|
||||
|
||||
def test_value(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessValue} is fired with the exit
|
||||
status of the child process.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(["raise SystemExit(1)"])
|
||||
|
||||
d = utils.getProcessValue(self.exe, ['-u', scriptFile])
|
||||
return d.addCallback(self.assertEqual, 1)
|
||||
|
||||
|
||||
def test_outputAndValue(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessOutputAndValue} fires with a
|
||||
three-tuple, the elements of which give the data written to the child's
|
||||
stdout, the data written to the child's stderr, and the exit status of
|
||||
the child.
|
||||
"""
|
||||
exe = sys.executable
|
||||
scriptFile = self.makeSourceFile([
|
||||
"import sys",
|
||||
"sys.stdout.write('hello world!\\n')",
|
||||
"sys.stderr.write('goodbye world!\\n')",
|
||||
"sys.exit(1)"
|
||||
])
|
||||
|
||||
def gotOutputAndValue(out_err_code):
|
||||
out, err, code = out_err_code
|
||||
self.assertEqual(out, "hello world!\n")
|
||||
self.assertEqual(err, "goodbye world!" + os.linesep)
|
||||
self.assertEqual(code, 1)
|
||||
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
|
||||
return d.addCallback(gotOutputAndValue)
|
||||
|
||||
|
||||
def test_outputSignal(self):
|
||||
"""
|
||||
If the child process exits because of a signal, the L{Deferred}
|
||||
returned by L{getProcessOutputAndValue} fires a L{Failure} of a tuple
|
||||
containing the the child's stdout, stderr, and the signal which caused
|
||||
it to exit.
|
||||
"""
|
||||
# Use SIGKILL here because it's guaranteed to be delivered. Using
|
||||
# SIGHUP might not work in, e.g., a buildbot slave run under the
|
||||
# 'nohup' command.
|
||||
scriptFile = self.makeSourceFile([
|
||||
"import sys, os, signal",
|
||||
"sys.stdout.write('stdout bytes\\n')",
|
||||
"sys.stderr.write('stderr bytes\\n')",
|
||||
"sys.stdout.flush()",
|
||||
"sys.stderr.flush()",
|
||||
"os.kill(os.getpid(), signal.SIGKILL)"])
|
||||
|
||||
def gotOutputAndValue(out_err_sig):
|
||||
out, err, sig = out_err_sig
|
||||
self.assertEqual(out, "stdout bytes\n")
|
||||
self.assertEqual(err, "stderr bytes\n")
|
||||
self.assertEqual(sig, signal.SIGKILL)
|
||||
|
||||
d = utils.getProcessOutputAndValue(self.exe, ['-u', scriptFile])
|
||||
d = self.assertFailure(d, tuple)
|
||||
return d.addCallback(gotOutputAndValue)
|
||||
|
||||
if platform.isWindows():
|
||||
test_outputSignal.skip = "Windows doesn't have real signals."
|
||||
|
||||
|
||||
def _pathTest(self, utilFunc, check):
|
||||
dir = os.path.abspath(self.mktemp())
|
||||
os.makedirs(dir)
|
||||
scriptFile = self.makeSourceFile([
|
||||
"import os, sys",
|
||||
"sys.stdout.write(os.getcwd())"])
|
||||
d = utilFunc(self.exe, ['-u', scriptFile], path=dir)
|
||||
d.addCallback(check, dir)
|
||||
return d
|
||||
|
||||
|
||||
def test_getProcessOutputPath(self):
|
||||
"""
|
||||
L{getProcessOutput} runs the given command with the working directory
|
||||
given by the C{path} parameter.
|
||||
"""
|
||||
return self._pathTest(utils.getProcessOutput, self.assertEqual)
|
||||
|
||||
|
||||
def test_getProcessValuePath(self):
|
||||
"""
|
||||
L{getProcessValue} runs the given command with the working directory
|
||||
given by the C{path} parameter.
|
||||
"""
|
||||
def check(result, ignored):
|
||||
self.assertEqual(result, 0)
|
||||
return self._pathTest(utils.getProcessValue, check)
|
||||
|
||||
|
||||
def test_getProcessOutputAndValuePath(self):
|
||||
"""
|
||||
L{getProcessOutputAndValue} runs the given command with the working
|
||||
directory given by the C{path} parameter.
|
||||
"""
|
||||
def check(out_err_status, dir):
|
||||
out, err, status = out_err_status
|
||||
self.assertEqual(out, dir)
|
||||
self.assertEqual(status, 0)
|
||||
return self._pathTest(utils.getProcessOutputAndValue, check)
|
||||
|
||||
|
||||
def _defaultPathTest(self, utilFunc, check):
|
||||
# Make another directory to mess around with.
|
||||
dir = os.path.abspath(self.mktemp())
|
||||
os.makedirs(dir)
|
||||
|
||||
scriptFile = self.makeSourceFile([
|
||||
"import os, sys, stat",
|
||||
# Fix the permissions so we can report the working directory.
|
||||
# On OS X (and maybe elsewhere), os.getcwd() fails with EACCES
|
||||
# if +x is missing from the working directory.
|
||||
"os.chmod(%r, stat.S_IXUSR)" % (dir,),
|
||||
"sys.stdout.write(os.getcwd())"])
|
||||
|
||||
# Switch to it, but make sure we switch back
|
||||
self.addCleanup(os.chdir, os.getcwd())
|
||||
os.chdir(dir)
|
||||
|
||||
# Get rid of all its permissions, but make sure they get cleaned up
|
||||
# later, because otherwise it might be hard to delete the trial
|
||||
# temporary directory.
|
||||
self.addCleanup(
|
||||
os.chmod, dir, stat.S_IMODE(os.stat('.').st_mode))
|
||||
os.chmod(dir, 0)
|
||||
|
||||
d = utilFunc(self.exe, ['-u', scriptFile])
|
||||
d.addCallback(check, dir)
|
||||
return d
|
||||
|
||||
|
||||
def test_getProcessOutputDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter, L{getProcessOutput}
|
||||
runs the given command in the same working directory as the parent
|
||||
process and succeeds even if the current working directory is not
|
||||
accessible.
|
||||
"""
|
||||
return self._defaultPathTest(utils.getProcessOutput, self.assertEqual)
|
||||
|
||||
|
||||
def test_getProcessValueDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter, L{getProcessValue}
|
||||
runs the given command in the same working directory as the parent
|
||||
process and succeeds even if the current working directory is not
|
||||
accessible.
|
||||
"""
|
||||
def check(result, ignored):
|
||||
self.assertEqual(result, 0)
|
||||
return self._defaultPathTest(utils.getProcessValue, check)
|
||||
|
||||
|
||||
def test_getProcessOutputAndValueDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter,
|
||||
L{getProcessOutputAndValue} runs the given command in the same working
|
||||
directory as the parent process and succeeds even if the current
|
||||
working directory is not accessible.
|
||||
"""
|
||||
def check(out_err_status, dir):
|
||||
out, err, status = out_err_status
|
||||
self.assertEqual(out, dir)
|
||||
self.assertEqual(status, 0)
|
||||
return self._defaultPathTest(
|
||||
utils.getProcessOutputAndValue, check)
|
||||
|
||||
|
||||
|
||||
class SuppressWarningsTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{utils.suppressWarnings}.
|
||||
"""
|
||||
def test_suppressWarnings(self):
|
||||
"""
|
||||
L{utils.suppressWarnings} decorates a function so that the given
|
||||
warnings are suppressed.
|
||||
"""
|
||||
result = []
|
||||
def showwarning(self, *a, **kw):
|
||||
result.append((a, kw))
|
||||
self.patch(warnings, "showwarning", showwarning)
|
||||
|
||||
def f(msg):
|
||||
warnings.warn(msg)
|
||||
g = utils.suppressWarnings(f, (('ignore',), dict(message="This is message")))
|
||||
|
||||
# Start off with a sanity check - calling the original function
|
||||
# should emit the warning.
|
||||
f("Sanity check message")
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
# Now that that's out of the way, call the wrapped function, and
|
||||
# make sure no new warnings show up.
|
||||
g("This is message")
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
# Finally, emit another warning which should not be ignored, and
|
||||
# make sure it is not.
|
||||
g("Unignored message")
|
||||
self.assertEqual(len(result), 2)
|
||||
|
||||
|
||||
|
||||
class DeferredSuppressedWarningsTests(SuppressedWarningsTests):
|
||||
"""
|
||||
Tests for L{utils.runWithWarningsSuppressed}, the version that supports
|
||||
Deferreds.
|
||||
"""
|
||||
# Override the non-Deferred-supporting function from the base class with
|
||||
# the function we are testing in this class:
|
||||
runWithWarningsSuppressed = staticmethod(utils.runWithWarningsSuppressed)
|
||||
|
||||
def test_deferredCallback(self):
|
||||
"""
|
||||
If the function called by L{utils.runWithWarningsSuppressed} returns a
|
||||
C{Deferred}, the warning filters aren't removed until the Deferred
|
||||
fires.
|
||||
"""
|
||||
filters = [(("ignore", ".*foo.*"), {}),
|
||||
(("ignore", ".*bar.*"), {})]
|
||||
result = Deferred()
|
||||
self.runWithWarningsSuppressed(filters, lambda: result)
|
||||
warnings.warn("ignore foo")
|
||||
result.callback(3)
|
||||
warnings.warn("ignore foo 2")
|
||||
self.assertEqual(
|
||||
["ignore foo 2"], [w['message'] for w in self.flushWarnings()])
|
||||
|
||||
|
||||
def test_deferredErrback(self):
|
||||
"""
|
||||
If the function called by L{utils.runWithWarningsSuppressed} returns a
|
||||
C{Deferred}, the warning filters aren't removed until the Deferred
|
||||
fires with an errback.
|
||||
"""
|
||||
filters = [(("ignore", ".*foo.*"), {}),
|
||||
(("ignore", ".*bar.*"), {})]
|
||||
result = Deferred()
|
||||
d = self.runWithWarningsSuppressed(filters, lambda: result)
|
||||
warnings.warn("ignore foo")
|
||||
result.errback(ZeroDivisionError())
|
||||
d.addErrback(lambda f: f.trap(ZeroDivisionError))
|
||||
warnings.warn("ignore foo 2")
|
||||
self.assertEqual(
|
||||
["ignore foo 2"], [w['message'] for w in self.flushWarnings()])
|
||||
|
||||
if _PY3:
|
||||
del ProcessUtilsTests
|
||||
|
|
@ -0,0 +1,640 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{jelly} object serialization.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
|
||||
from twisted.spread import jelly, pb
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
class TestNode(object, jelly.Jellyable):
|
||||
"""
|
||||
An object to test jellyfying of new style class instances.
|
||||
"""
|
||||
classAttr = 4
|
||||
|
||||
def __init__(self, parent=None):
|
||||
if parent:
|
||||
self.id = parent.id + 1
|
||||
parent.children.append(self)
|
||||
else:
|
||||
self.id = 1
|
||||
self.parent = parent
|
||||
self.children = []
|
||||
|
||||
|
||||
|
||||
class A:
|
||||
"""
|
||||
Dummy class.
|
||||
"""
|
||||
|
||||
def amethod(self):
|
||||
"""
|
||||
Method tp be used in serialization tests.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def afunc(self):
|
||||
"""
|
||||
A dummy function to test function serialization.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class B:
|
||||
"""
|
||||
Dummy class.
|
||||
"""
|
||||
|
||||
def bmethod(self):
|
||||
"""
|
||||
Method to be used in serialization tests.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class C:
|
||||
"""
|
||||
Dummy class.
|
||||
"""
|
||||
|
||||
def cmethod(self):
|
||||
"""
|
||||
Method to be used in serialization tests.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class D(object):
|
||||
"""
|
||||
Dummy new-style class.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class E(object):
|
||||
"""
|
||||
Dummy new-style class with slots.
|
||||
"""
|
||||
|
||||
__slots__ = ("x", "y")
|
||||
|
||||
def __init__(self, x=None, y=None):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
return {"x" : self.x, "y" : self.y}
|
||||
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.x = state["x"]
|
||||
self.y = state["y"]
|
||||
|
||||
|
||||
|
||||
class SimpleJellyTest:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def isTheSameAs(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
|
||||
|
||||
|
||||
class JellyTestCase(unittest.TestCase):
|
||||
"""
|
||||
Testcases for L{jelly} module serialization.
|
||||
|
||||
@cvar decimalData: serialized version of decimal data, to be used in tests.
|
||||
@type decimalData: C{list}
|
||||
"""
|
||||
|
||||
def _testSecurity(self, inputList, atom):
|
||||
"""
|
||||
Helper test method to test security options for a type.
|
||||
|
||||
@param inputList: a sample input for the type.
|
||||
@param inputList: C{list}
|
||||
|
||||
@param atom: atom identifier for the type.
|
||||
@type atom: C{str}
|
||||
"""
|
||||
c = jelly.jelly(inputList)
|
||||
taster = jelly.SecurityOptions()
|
||||
taster.allowBasicTypes()
|
||||
# By default, it should succeed
|
||||
jelly.unjelly(c, taster)
|
||||
taster.allowedTypes.pop(atom)
|
||||
# But it should raise an exception when disallowed
|
||||
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, c, taster)
|
||||
|
||||
|
||||
def test_methodSelfIdentity(self):
|
||||
a = A()
|
||||
b = B()
|
||||
a.bmethod = b.bmethod
|
||||
b.a = a
|
||||
im_ = jelly.unjelly(jelly.jelly(b)).a.bmethod
|
||||
self.assertEqual(im_.im_class, im_.im_self.__class__)
|
||||
|
||||
|
||||
def test_methodsNotSelfIdentity(self):
|
||||
"""
|
||||
If a class change after an instance has been created, L{jelly.unjelly}
|
||||
shoud raise a C{TypeError} when trying to unjelly the instance.
|
||||
"""
|
||||
a = A()
|
||||
b = B()
|
||||
c = C()
|
||||
a.bmethod = c.cmethod
|
||||
b.a = a
|
||||
savecmethod = C.cmethod
|
||||
del C.cmethod
|
||||
try:
|
||||
self.assertRaises(TypeError, jelly.unjelly, jelly.jelly(b))
|
||||
finally:
|
||||
C.cmethod = savecmethod
|
||||
|
||||
|
||||
def test_newStyle(self):
|
||||
n = D()
|
||||
n.x = 1
|
||||
n2 = D()
|
||||
n.n2 = n2
|
||||
n.n3 = n2
|
||||
c = jelly.jelly(n)
|
||||
m = jelly.unjelly(c)
|
||||
self.assertIsInstance(m, D)
|
||||
self.assertIdentical(m.n2, m.n3)
|
||||
|
||||
|
||||
def test_newStyleWithSlots(self):
|
||||
"""
|
||||
A class defined with I{slots} can be jellied and unjellied with the
|
||||
values for its attributes preserved.
|
||||
"""
|
||||
n = E()
|
||||
n.x = 1
|
||||
c = jelly.jelly(n)
|
||||
m = jelly.unjelly(c)
|
||||
self.assertIsInstance(m, E)
|
||||
self.assertEqual(n.x, 1)
|
||||
|
||||
|
||||
def test_typeOldStyle(self):
|
||||
"""
|
||||
Test that an old style class type can be jellied and unjellied
|
||||
to the original type.
|
||||
"""
|
||||
t = [C]
|
||||
r = jelly.unjelly(jelly.jelly(t))
|
||||
self.assertEqual(t, r)
|
||||
|
||||
|
||||
def test_typeNewStyle(self):
|
||||
"""
|
||||
Test that a new style class type can be jellied and unjellied
|
||||
to the original type.
|
||||
"""
|
||||
t = [D]
|
||||
r = jelly.unjelly(jelly.jelly(t))
|
||||
self.assertEqual(t, r)
|
||||
|
||||
|
||||
def test_typeBuiltin(self):
|
||||
"""
|
||||
Test that a builtin type can be jellied and unjellied to the original
|
||||
type.
|
||||
"""
|
||||
t = [str]
|
||||
r = jelly.unjelly(jelly.jelly(t))
|
||||
self.assertEqual(t, r)
|
||||
|
||||
|
||||
def test_dateTime(self):
|
||||
dtn = datetime.datetime.now()
|
||||
dtd = datetime.datetime.now() - dtn
|
||||
input = [dtn, dtd]
|
||||
c = jelly.jelly(input)
|
||||
output = jelly.unjelly(c)
|
||||
self.assertEqual(input, output)
|
||||
self.assertNotIdentical(input, output)
|
||||
|
||||
|
||||
def test_decimal(self):
|
||||
"""
|
||||
Jellying L{decimal.Decimal} instances and then unjellying the result
|
||||
should produce objects which represent the values of the original
|
||||
inputs.
|
||||
"""
|
||||
inputList = [decimal.Decimal('9.95'),
|
||||
decimal.Decimal(0),
|
||||
decimal.Decimal(123456),
|
||||
decimal.Decimal('-78.901')]
|
||||
c = jelly.jelly(inputList)
|
||||
output = jelly.unjelly(c)
|
||||
self.assertEqual(inputList, output)
|
||||
self.assertNotIdentical(inputList, output)
|
||||
|
||||
|
||||
decimalData = ['list', ['decimal', 995, -2], ['decimal', 0, 0],
|
||||
['decimal', 123456, 0], ['decimal', -78901, -3]]
|
||||
|
||||
|
||||
def test_decimalUnjelly(self):
|
||||
"""
|
||||
Unjellying the s-expressions produced by jelly for L{decimal.Decimal}
|
||||
instances should result in L{decimal.Decimal} instances with the values
|
||||
represented by the s-expressions.
|
||||
|
||||
This test also verifies that C{self.decimalData} contains valid jellied
|
||||
data. This is important since L{test_decimalMissing} re-uses
|
||||
C{self.decimalData} and is expected to be unable to produce
|
||||
L{decimal.Decimal} instances even though the s-expression correctly
|
||||
represents a list of them.
|
||||
"""
|
||||
expected = [decimal.Decimal('9.95'),
|
||||
decimal.Decimal(0),
|
||||
decimal.Decimal(123456),
|
||||
decimal.Decimal('-78.901')]
|
||||
output = jelly.unjelly(self.decimalData)
|
||||
self.assertEqual(output, expected)
|
||||
|
||||
|
||||
def test_decimalSecurity(self):
|
||||
"""
|
||||
By default, C{decimal} objects should be allowed by
|
||||
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
|
||||
L{jelly.InsecureJelly} when trying to unjelly it.
|
||||
"""
|
||||
inputList = [decimal.Decimal('9.95')]
|
||||
self._testSecurity(inputList, "decimal")
|
||||
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
Jellying C{set} instances and then unjellying the result
|
||||
should produce objects which represent the values of the original
|
||||
inputs.
|
||||
"""
|
||||
inputList = [set([1, 2, 3])]
|
||||
output = jelly.unjelly(jelly.jelly(inputList))
|
||||
self.assertEqual(inputList, output)
|
||||
self.assertNotIdentical(inputList, output)
|
||||
|
||||
|
||||
def test_frozenset(self):
|
||||
"""
|
||||
Jellying C{frozenset} instances and then unjellying the result
|
||||
should produce objects which represent the values of the original
|
||||
inputs.
|
||||
"""
|
||||
inputList = [frozenset([1, 2, 3])]
|
||||
output = jelly.unjelly(jelly.jelly(inputList))
|
||||
self.assertEqual(inputList, output)
|
||||
self.assertNotIdentical(inputList, output)
|
||||
|
||||
|
||||
def test_setSecurity(self):
|
||||
"""
|
||||
By default, C{set} objects should be allowed by
|
||||
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
|
||||
L{jelly.InsecureJelly} when trying to unjelly it.
|
||||
"""
|
||||
inputList = [set([1, 2, 3])]
|
||||
self._testSecurity(inputList, "set")
|
||||
|
||||
|
||||
def test_frozensetSecurity(self):
|
||||
"""
|
||||
By default, C{frozenset} objects should be allowed by
|
||||
L{jelly.SecurityOptions}. If not allowed, L{jelly.unjelly} should raise
|
||||
L{jelly.InsecureJelly} when trying to unjelly it.
|
||||
"""
|
||||
inputList = [frozenset([1, 2, 3])]
|
||||
self._testSecurity(inputList, "frozenset")
|
||||
|
||||
|
||||
def test_oldSets(self):
|
||||
"""
|
||||
Test jellying C{sets.Set}: it should serialize to the same thing as
|
||||
C{set} jelly, and be unjellied as C{set} if available.
|
||||
"""
|
||||
inputList = [jelly._sets.Set([1, 2, 3])]
|
||||
inputJelly = jelly.jelly(inputList)
|
||||
self.assertEqual(inputJelly, jelly.jelly([set([1, 2, 3])]))
|
||||
output = jelly.unjelly(inputJelly)
|
||||
# Even if the class is different, it should coerce to the same list
|
||||
self.assertEqual(list(inputList[0]), list(output[0]))
|
||||
if set is jelly._sets.Set:
|
||||
self.assertIsInstance(output[0], jelly._sets.Set)
|
||||
else:
|
||||
self.assertIsInstance(output[0], set)
|
||||
|
||||
|
||||
def test_oldImmutableSets(self):
|
||||
"""
|
||||
Test jellying C{sets.ImmutableSet}: it should serialize to the same
|
||||
thing as C{frozenset} jelly, and be unjellied as C{frozenset} if
|
||||
available.
|
||||
"""
|
||||
inputList = [jelly._sets.ImmutableSet([1, 2, 3])]
|
||||
inputJelly = jelly.jelly(inputList)
|
||||
self.assertEqual(inputJelly, jelly.jelly([frozenset([1, 2, 3])]))
|
||||
output = jelly.unjelly(inputJelly)
|
||||
# Even if the class is different, it should coerce to the same list
|
||||
self.assertEqual(list(inputList[0]), list(output[0]))
|
||||
if frozenset is jelly._sets.ImmutableSet:
|
||||
self.assertIsInstance(output[0], jelly._sets.ImmutableSet)
|
||||
else:
|
||||
self.assertIsInstance(output[0], frozenset)
|
||||
|
||||
|
||||
def test_simple(self):
|
||||
"""
|
||||
Simplest test case.
|
||||
"""
|
||||
self.failUnless(SimpleJellyTest('a', 'b').isTheSameAs(
|
||||
SimpleJellyTest('a', 'b')))
|
||||
a = SimpleJellyTest(1, 2)
|
||||
cereal = jelly.jelly(a)
|
||||
b = jelly.unjelly(cereal)
|
||||
self.failUnless(a.isTheSameAs(b))
|
||||
|
||||
|
||||
def test_identity(self):
|
||||
"""
|
||||
Test to make sure that objects retain identity properly.
|
||||
"""
|
||||
x = []
|
||||
y = (x)
|
||||
x.append(y)
|
||||
x.append(y)
|
||||
self.assertIdentical(x[0], x[1])
|
||||
self.assertIdentical(x[0][0], x)
|
||||
s = jelly.jelly(x)
|
||||
z = jelly.unjelly(s)
|
||||
self.assertIdentical(z[0], z[1])
|
||||
self.assertIdentical(z[0][0], z)
|
||||
|
||||
|
||||
def test_unicode(self):
|
||||
x = unicode('blah')
|
||||
y = jelly.unjelly(jelly.jelly(x))
|
||||
self.assertEqual(x, y)
|
||||
self.assertEqual(type(x), type(y))
|
||||
|
||||
|
||||
def test_stressReferences(self):
|
||||
reref = []
|
||||
toplevelTuple = ({'list': reref}, reref)
|
||||
reref.append(toplevelTuple)
|
||||
s = jelly.jelly(toplevelTuple)
|
||||
z = jelly.unjelly(s)
|
||||
self.assertIdentical(z[0]['list'], z[1])
|
||||
self.assertIdentical(z[0]['list'][0], z)
|
||||
|
||||
|
||||
def test_moreReferences(self):
|
||||
a = []
|
||||
t = (a,)
|
||||
a.append((t,))
|
||||
s = jelly.jelly(t)
|
||||
z = jelly.unjelly(s)
|
||||
self.assertIdentical(z[0][0][0], z)
|
||||
|
||||
|
||||
def test_typeSecurity(self):
|
||||
"""
|
||||
Test for type-level security of serialization.
|
||||
"""
|
||||
taster = jelly.SecurityOptions()
|
||||
dct = jelly.jelly({})
|
||||
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, dct, taster)
|
||||
|
||||
|
||||
def test_newStyleClasses(self):
|
||||
j = jelly.jelly(D)
|
||||
uj = jelly.unjelly(D)
|
||||
self.assertIdentical(D, uj)
|
||||
|
||||
|
||||
def test_lotsaTypes(self):
|
||||
"""
|
||||
Test for all types currently supported in jelly
|
||||
"""
|
||||
a = A()
|
||||
jelly.unjelly(jelly.jelly(a))
|
||||
jelly.unjelly(jelly.jelly(a.amethod))
|
||||
items = [afunc, [1, 2, 3], not bool(1), bool(1), 'test', 20.3,
|
||||
(1, 2, 3), None, A, unittest, {'a': 1}, A.amethod]
|
||||
for i in items:
|
||||
self.assertEqual(i, jelly.unjelly(jelly.jelly(i)))
|
||||
|
||||
|
||||
def test_setState(self):
|
||||
global TupleState
|
||||
class TupleState:
|
||||
def __init__(self, other):
|
||||
self.other = other
|
||||
def __getstate__(self):
|
||||
return (self.other,)
|
||||
def __setstate__(self, state):
|
||||
self.other = state[0]
|
||||
def __hash__(self):
|
||||
return hash(self.other)
|
||||
a = A()
|
||||
t1 = TupleState(a)
|
||||
t2 = TupleState(a)
|
||||
t3 = TupleState((t1, t2))
|
||||
d = {t1: t1, t2: t2, t3: t3, "t3": t3}
|
||||
t3prime = jelly.unjelly(jelly.jelly(d))["t3"]
|
||||
self.assertIdentical(t3prime.other[0].other, t3prime.other[1].other)
|
||||
|
||||
|
||||
def test_classSecurity(self):
|
||||
"""
|
||||
Test for class-level security of serialization.
|
||||
"""
|
||||
taster = jelly.SecurityOptions()
|
||||
taster.allowInstancesOf(A, B)
|
||||
a = A()
|
||||
b = B()
|
||||
c = C()
|
||||
# add a little complexity to the data
|
||||
a.b = b
|
||||
a.c = c
|
||||
# and a backreference
|
||||
a.x = b
|
||||
b.c = c
|
||||
# first, a friendly insecure serialization
|
||||
friendly = jelly.jelly(a, taster)
|
||||
x = jelly.unjelly(friendly, taster)
|
||||
self.assertIsInstance(x.c, jelly.Unpersistable)
|
||||
# now, a malicious one
|
||||
mean = jelly.jelly(a)
|
||||
self.assertRaises(jelly.InsecureJelly, jelly.unjelly, mean, taster)
|
||||
self.assertIdentical(x.x, x.b, "Identity mismatch")
|
||||
# test class serialization
|
||||
friendly = jelly.jelly(A, taster)
|
||||
x = jelly.unjelly(friendly, taster)
|
||||
self.assertIdentical(x, A, "A came back: %s" % x)
|
||||
|
||||
|
||||
def test_unjellyable(self):
|
||||
"""
|
||||
Test that if Unjellyable is used to deserialize a jellied object,
|
||||
state comes out right.
|
||||
"""
|
||||
class JellyableTestClass(jelly.Jellyable):
|
||||
pass
|
||||
jelly.setUnjellyableForClass(JellyableTestClass, jelly.Unjellyable)
|
||||
input = JellyableTestClass()
|
||||
input.attribute = 'value'
|
||||
output = jelly.unjelly(jelly.jelly(input))
|
||||
self.assertEqual(output.attribute, 'value')
|
||||
self.assertIsInstance(output, jelly.Unjellyable)
|
||||
|
||||
|
||||
def test_persistentStorage(self):
|
||||
perst = [{}, 1]
|
||||
def persistentStore(obj, jel, perst = perst):
|
||||
perst[1] = perst[1] + 1
|
||||
perst[0][perst[1]] = obj
|
||||
return str(perst[1])
|
||||
|
||||
def persistentLoad(pidstr, unj, perst = perst):
|
||||
pid = int(pidstr)
|
||||
return perst[0][pid]
|
||||
|
||||
a = SimpleJellyTest(1, 2)
|
||||
b = SimpleJellyTest(3, 4)
|
||||
c = SimpleJellyTest(5, 6)
|
||||
|
||||
a.b = b
|
||||
a.c = c
|
||||
c.b = b
|
||||
|
||||
jel = jelly.jelly(a, persistentStore = persistentStore)
|
||||
x = jelly.unjelly(jel, persistentLoad = persistentLoad)
|
||||
|
||||
self.assertIdentical(x.b, x.c.b)
|
||||
self.failUnless(perst[0], "persistentStore was not called.")
|
||||
self.assertIdentical(x.b, a.b, "Persistent storage identity failure.")
|
||||
|
||||
|
||||
def test_newStyleClassesAttributes(self):
|
||||
n = TestNode()
|
||||
n1 = TestNode(n)
|
||||
n11 = TestNode(n1)
|
||||
n2 = TestNode(n)
|
||||
# Jelly it
|
||||
jel = jelly.jelly(n)
|
||||
m = jelly.unjelly(jel)
|
||||
# Check that it has been restored ok
|
||||
self._check_newstyle(n, m)
|
||||
|
||||
|
||||
def _check_newstyle(self, a, b):
|
||||
self.assertEqual(a.id, b.id)
|
||||
self.assertEqual(a.classAttr, 4)
|
||||
self.assertEqual(b.classAttr, 4)
|
||||
self.assertEqual(len(a.children), len(b.children))
|
||||
for x, y in zip(a.children, b.children):
|
||||
self._check_newstyle(x, y)
|
||||
|
||||
|
||||
def test_referenceable(self):
|
||||
"""
|
||||
A L{pb.Referenceable} instance jellies to a structure which unjellies to
|
||||
a L{pb.RemoteReference}. The C{RemoteReference} has a I{luid} that
|
||||
matches up with the local object key in the L{pb.Broker} which sent the
|
||||
L{Referenceable}.
|
||||
"""
|
||||
ref = pb.Referenceable()
|
||||
jellyBroker = pb.Broker()
|
||||
jellyBroker.makeConnection(StringTransport())
|
||||
j = jelly.jelly(ref, invoker=jellyBroker)
|
||||
|
||||
unjellyBroker = pb.Broker()
|
||||
unjellyBroker.makeConnection(StringTransport())
|
||||
|
||||
uj = jelly.unjelly(j, invoker=unjellyBroker)
|
||||
self.assertIn(uj.luid, jellyBroker.localObjects)
|
||||
|
||||
|
||||
|
||||
class ClassA(pb.Copyable, pb.RemoteCopy):
|
||||
def __init__(self):
|
||||
self.ref = ClassB(self)
|
||||
|
||||
|
||||
|
||||
class ClassB(pb.Copyable, pb.RemoteCopy):
|
||||
def __init__(self, ref):
|
||||
self.ref = ref
|
||||
|
||||
|
||||
|
||||
class CircularReferenceTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for circular references handling in the jelly/unjelly process.
|
||||
"""
|
||||
|
||||
def test_simpleCircle(self):
|
||||
jelly.setUnjellyableForClass(ClassA, ClassA)
|
||||
jelly.setUnjellyableForClass(ClassB, ClassB)
|
||||
a = jelly.unjelly(jelly.jelly(ClassA()))
|
||||
self.assertIdentical(a.ref.ref, a,
|
||||
"Identity not preserved in circular reference")
|
||||
|
||||
|
||||
def test_circleWithInvoker(self):
|
||||
class DummyInvokerClass:
|
||||
pass
|
||||
dummyInvoker = DummyInvokerClass()
|
||||
dummyInvoker.serializingPerspective = None
|
||||
a0 = ClassA()
|
||||
jelly.setUnjellyableForClass(ClassA, ClassA)
|
||||
jelly.setUnjellyableForClass(ClassB, ClassB)
|
||||
j = jelly.jelly(a0, invoker=dummyInvoker)
|
||||
a1 = jelly.unjelly(j)
|
||||
self.failUnlessIdentical(a1.ref.ref, a1,
|
||||
"Identity not preserved in circular reference")
|
||||
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
Check that a C{set} can contain a circular reference and be serialized
|
||||
and unserialized without losing the reference.
|
||||
"""
|
||||
s = set()
|
||||
a = SimpleJellyTest(s, None)
|
||||
s.add(a)
|
||||
res = jelly.unjelly(jelly.jelly(a))
|
||||
self.assertIsInstance(res.x, set)
|
||||
self.assertEqual(list(res.x), [res])
|
||||
|
||||
|
||||
def test_frozenset(self):
|
||||
"""
|
||||
Check that a C{frozenset} can contain a circular reference and be
|
||||
serializeserialized without losing the reference.
|
||||
"""
|
||||
a = SimpleJellyTest(None, None)
|
||||
s = frozenset([a])
|
||||
a.x = s
|
||||
res = jelly.unjelly(jelly.jelly(a))
|
||||
self.assertIsInstance(res.x, frozenset)
|
||||
self.assertEqual(list(res.x), [res])
|
||||
|
|
@ -0,0 +1,445 @@
|
|||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.lockfile}.
|
||||
"""
|
||||
|
||||
import os, errno
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import lockfile
|
||||
from twisted.python.runtime import platform
|
||||
|
||||
skipKill = None
|
||||
if platform.isWindows():
|
||||
try:
|
||||
from win32api import OpenProcess
|
||||
import pywintypes
|
||||
except ImportError:
|
||||
skipKill = ("On windows, lockfile.kill is not implemented in the "
|
||||
"absence of win32api and/or pywintypes.")
|
||||
|
||||
class UtilTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the helper functions used to implement L{FilesystemLock}.
|
||||
"""
|
||||
def test_symlinkEEXIST(self):
|
||||
"""
|
||||
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EEXIST}
|
||||
when an attempt is made to create a symlink which already exists.
|
||||
"""
|
||||
name = self.mktemp()
|
||||
lockfile.symlink('foo', name)
|
||||
exc = self.assertRaises(OSError, lockfile.symlink, 'foo', name)
|
||||
self.assertEqual(exc.errno, errno.EEXIST)
|
||||
|
||||
|
||||
def test_symlinkEIOWindows(self):
|
||||
"""
|
||||
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EIO} when
|
||||
the underlying L{rename} call fails with L{EIO}.
|
||||
|
||||
Renaming a file on Windows may fail if the target of the rename is in
|
||||
the process of being deleted (directory deletion appears not to be
|
||||
atomic).
|
||||
"""
|
||||
name = self.mktemp()
|
||||
def fakeRename(src, dst):
|
||||
raise IOError(errno.EIO, None)
|
||||
self.patch(lockfile, 'rename', fakeRename)
|
||||
exc = self.assertRaises(IOError, lockfile.symlink, name, "foo")
|
||||
self.assertEqual(exc.errno, errno.EIO)
|
||||
if not platform.isWindows():
|
||||
test_symlinkEIOWindows.skip = (
|
||||
"special rename EIO handling only necessary and correct on "
|
||||
"Windows.")
|
||||
|
||||
|
||||
def test_readlinkENOENT(self):
|
||||
"""
|
||||
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{ENOENT}
|
||||
when an attempt is made to read a symlink which does not exist.
|
||||
"""
|
||||
name = self.mktemp()
|
||||
exc = self.assertRaises(OSError, lockfile.readlink, name)
|
||||
self.assertEqual(exc.errno, errno.ENOENT)
|
||||
|
||||
|
||||
def test_readlinkEACCESWindows(self):
|
||||
"""
|
||||
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{EACCES}
|
||||
on Windows when the underlying file open attempt fails with C{EACCES}.
|
||||
|
||||
Opening a file on Windows may fail if the path is inside a directory
|
||||
which is in the process of being deleted (directory deletion appears
|
||||
not to be atomic).
|
||||
"""
|
||||
name = self.mktemp()
|
||||
def fakeOpen(path, mode):
|
||||
raise IOError(errno.EACCES, None)
|
||||
self.patch(lockfile, '_open', fakeOpen)
|
||||
exc = self.assertRaises(IOError, lockfile.readlink, name)
|
||||
self.assertEqual(exc.errno, errno.EACCES)
|
||||
if not platform.isWindows():
|
||||
test_readlinkEACCESWindows.skip = (
|
||||
"special readlink EACCES handling only necessary and correct on "
|
||||
"Windows.")
|
||||
|
||||
|
||||
def test_kill(self):
|
||||
"""
|
||||
L{lockfile.kill} returns without error if passed the PID of a
|
||||
process which exists and signal C{0}.
|
||||
"""
|
||||
lockfile.kill(os.getpid(), 0)
|
||||
test_kill.skip = skipKill
|
||||
|
||||
|
||||
def test_killESRCH(self):
|
||||
"""
|
||||
L{lockfile.kill} raises L{OSError} with errno of L{ESRCH} if
|
||||
passed a PID which does not correspond to any process.
|
||||
"""
|
||||
# Hopefully there is no process with PID 2 ** 31 - 1
|
||||
exc = self.assertRaises(OSError, lockfile.kill, 2 ** 31 - 1, 0)
|
||||
self.assertEqual(exc.errno, errno.ESRCH)
|
||||
test_killESRCH.skip = skipKill
|
||||
|
||||
|
||||
def test_noKillCall(self):
|
||||
"""
|
||||
Verify that when L{lockfile.kill} does end up as None (e.g. on Windows
|
||||
without pywin32), it doesn't end up being called and raising a
|
||||
L{TypeError}.
|
||||
"""
|
||||
self.patch(lockfile, "kill", None)
|
||||
fl = lockfile.FilesystemLock(self.mktemp())
|
||||
fl.lock()
|
||||
self.assertFalse(fl.lock())
|
||||
|
||||
|
||||
|
||||
class LockingTestCase(unittest.TestCase):
|
||||
def _symlinkErrorTest(self, errno):
|
||||
def fakeSymlink(source, dest):
|
||||
raise OSError(errno, None)
|
||||
self.patch(lockfile, 'symlink', fakeSymlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno)
|
||||
|
||||
|
||||
def test_symlinkError(self):
|
||||
"""
|
||||
An exception raised by C{symlink} other than C{EEXIST} is passed up to
|
||||
the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
self._symlinkErrorTest(errno.ENOSYS)
|
||||
|
||||
|
||||
def test_symlinkErrorPOSIX(self):
|
||||
"""
|
||||
An L{OSError} raised by C{symlink} on a POSIX platform with an errno of
|
||||
C{EACCES} or C{EIO} is passed to the caller of L{FilesystemLock.lock}.
|
||||
|
||||
On POSIX, unlike on Windows, these are unexpected errors which cannot
|
||||
be handled by L{FilesystemLock}.
|
||||
"""
|
||||
self._symlinkErrorTest(errno.EACCES)
|
||||
self._symlinkErrorTest(errno.EIO)
|
||||
if platform.isWindows():
|
||||
test_symlinkErrorPOSIX.skip = (
|
||||
"POSIX-specific error propagation not expected on Windows.")
|
||||
|
||||
|
||||
def test_cleanlyAcquire(self):
|
||||
"""
|
||||
If the lock has never been held, it can be acquired and the C{clean}
|
||||
and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
|
||||
def test_cleanlyRelease(self):
|
||||
"""
|
||||
If a lock is released cleanly, it can be re-acquired and the C{clean}
|
||||
and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
lock.unlock()
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
|
||||
def test_cannotLockLocked(self):
|
||||
"""
|
||||
If a lock is currently locked, it cannot be locked again.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
firstLock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(firstLock.lock())
|
||||
|
||||
secondLock = lockfile.FilesystemLock(lockf)
|
||||
self.assertFalse(secondLock.lock())
|
||||
self.assertFalse(secondLock.locked)
|
||||
|
||||
|
||||
def test_uncleanlyAcquire(self):
|
||||
"""
|
||||
If a lock was held by a process which no longer exists, it can be
|
||||
acquired, the C{clean} attribute is set to C{False}, and the
|
||||
C{locked} attribute is set to C{True}.
|
||||
"""
|
||||
owner = 12345
|
||||
|
||||
def fakeKill(pid, signal):
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == owner:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
|
||||
lockf = self.mktemp()
|
||||
self.patch(lockfile, 'kill', fakeKill)
|
||||
lockfile.symlink(str(owner), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertFalse(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
self.assertEqual(lockfile.readlink(lockf), str(os.getpid()))
|
||||
|
||||
|
||||
def test_lockReleasedBeforeCheck(self):
|
||||
"""
|
||||
If the lock is initially held but then released before it can be
|
||||
examined to determine if the process which held it still exists, it is
|
||||
acquired and the C{clean} and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
def fakeReadlink(name):
|
||||
# Pretend to be another process releasing the lock.
|
||||
lockfile.rmlink(lockf)
|
||||
# Fall back to the real implementation of readlink.
|
||||
readlinkPatch.restore()
|
||||
return lockfile.readlink(name)
|
||||
readlinkPatch = self.patch(lockfile, 'readlink', fakeReadlink)
|
||||
|
||||
def fakeKill(pid, signal):
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
self.patch(lockfile, 'kill', fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
|
||||
def test_lockReleasedDuringAcquireSymlink(self):
|
||||
"""
|
||||
If the lock is released while an attempt is made to acquire
|
||||
it, the lock attempt fails and C{FilesystemLock.lock} returns
|
||||
C{False}. This can happen on Windows when L{lockfile.symlink}
|
||||
fails with L{IOError} of C{EIO} because another process is in
|
||||
the middle of a call to L{os.rmdir} (implemented in terms of
|
||||
RemoveDirectory) which is not atomic.
|
||||
"""
|
||||
def fakeSymlink(src, dst):
|
||||
# While another process id doing os.rmdir which the Windows
|
||||
# implementation of rmlink does, a rename call will fail with EIO.
|
||||
raise OSError(errno.EIO, None)
|
||||
|
||||
self.patch(lockfile, 'symlink', fakeSymlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertFalse(lock.lock())
|
||||
self.assertFalse(lock.locked)
|
||||
if not platform.isWindows():
|
||||
test_lockReleasedDuringAcquireSymlink.skip = (
|
||||
"special rename EIO handling only necessary and correct on "
|
||||
"Windows.")
|
||||
|
||||
|
||||
def test_lockReleasedDuringAcquireReadlink(self):
|
||||
"""
|
||||
If the lock is initially held but is released while an attempt
|
||||
is made to acquire it, the lock attempt fails and
|
||||
L{FilesystemLock.lock} returns C{False}.
|
||||
"""
|
||||
def fakeReadlink(name):
|
||||
# While another process is doing os.rmdir which the
|
||||
# Windows implementation of rmlink does, a readlink call
|
||||
# will fail with EACCES.
|
||||
raise IOError(errno.EACCES, None)
|
||||
readlinkPatch = self.patch(lockfile, 'readlink', fakeReadlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertFalse(lock.lock())
|
||||
self.assertFalse(lock.locked)
|
||||
if not platform.isWindows():
|
||||
test_lockReleasedDuringAcquireReadlink.skip = (
|
||||
"special readlink EACCES handling only necessary and correct on "
|
||||
"Windows.")
|
||||
|
||||
|
||||
def _readlinkErrorTest(self, exceptionType, errno):
|
||||
def fakeReadlink(name):
|
||||
raise exceptionType(errno, None)
|
||||
self.patch(lockfile, 'readlink', fakeReadlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(exceptionType, lock.lock)
|
||||
self.assertEqual(exc.errno, errno)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
|
||||
def test_readlinkError(self):
|
||||
"""
|
||||
An exception raised by C{readlink} other than C{ENOENT} is passed up to
|
||||
the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
self._readlinkErrorTest(OSError, errno.ENOSYS)
|
||||
self._readlinkErrorTest(IOError, errno.ENOSYS)
|
||||
|
||||
|
||||
def test_readlinkErrorPOSIX(self):
|
||||
"""
|
||||
Any L{IOError} raised by C{readlink} on a POSIX platform passed to the
|
||||
caller of L{FilesystemLock.lock}.
|
||||
|
||||
On POSIX, unlike on Windows, these are unexpected errors which cannot
|
||||
be handled by L{FilesystemLock}.
|
||||
"""
|
||||
self._readlinkErrorTest(IOError, errno.ENOSYS)
|
||||
self._readlinkErrorTest(IOError, errno.EACCES)
|
||||
if platform.isWindows():
|
||||
test_readlinkErrorPOSIX.skip = (
|
||||
"POSIX-specific error propagation not expected on Windows.")
|
||||
|
||||
|
||||
def test_lockCleanedUpConcurrently(self):
|
||||
"""
|
||||
If a second process cleans up the lock after a first one checks the
|
||||
lock and finds that no process is holding it, the first process does
|
||||
not fail when it tries to clean up the lock.
|
||||
"""
|
||||
def fakeRmlink(name):
|
||||
rmlinkPatch.restore()
|
||||
# Pretend to be another process cleaning up the lock.
|
||||
lockfile.rmlink(lockf)
|
||||
# Fall back to the real implementation of rmlink.
|
||||
return lockfile.rmlink(name)
|
||||
rmlinkPatch = self.patch(lockfile, 'rmlink', fakeRmlink)
|
||||
|
||||
def fakeKill(pid, signal):
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
self.patch(lockfile, 'kill', fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
|
||||
def test_rmlinkError(self):
|
||||
"""
|
||||
An exception raised by L{rmlink} other than C{ENOENT} is passed up
|
||||
to the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
def fakeRmlink(name):
|
||||
raise OSError(errno.ENOSYS, None)
|
||||
self.patch(lockfile, 'rmlink', fakeRmlink)
|
||||
|
||||
def fakeKill(pid, signal):
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
self.patch(lockfile, 'kill', fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno.ENOSYS)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
|
||||
def test_killError(self):
|
||||
"""
|
||||
If L{kill} raises an exception other than L{OSError} with errno set to
|
||||
C{ESRCH}, the exception is passed up to the caller of
|
||||
L{FilesystemLock.lock}.
|
||||
"""
|
||||
def fakeKill(pid, signal):
|
||||
raise OSError(errno.EPERM, None)
|
||||
self.patch(lockfile, 'kill', fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno.EPERM)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
|
||||
def test_unlockOther(self):
|
||||
"""
|
||||
L{FilesystemLock.unlock} raises L{ValueError} if called for a lock
|
||||
which is held by a different process.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lockfile.symlink(str(os.getpid() + 1), lockf)
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertRaises(ValueError, lock.unlock)
|
||||
|
||||
|
||||
def test_isLocked(self):
|
||||
"""
|
||||
L{isLocked} returns C{True} if the named lock is currently locked,
|
||||
C{False} otherwise.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
self.assertFalse(lockfile.isLocked(lockf))
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lockfile.isLocked(lockf))
|
||||
lock.unlock()
|
||||
self.assertFalse(lockfile.isLocked(lockf))
|
||||
882
Linux_i686/lib/python2.7/site-packages/twisted/test/test_log.py
Normal file
882
Linux_i686/lib/python2.7/site-packages/twisted/test/test_log.py
Normal file
|
|
@ -0,0 +1,882 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.log}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
from twisted.python.compat import _PY3, NativeStringIO as StringIO
|
||||
|
||||
import os, sys, time, logging, warnings, calendar
|
||||
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
from twisted.python import log, failure
|
||||
|
||||
|
||||
class FakeWarning(Warning):
|
||||
"""
|
||||
A unique L{Warning} subclass used by tests for interactions of
|
||||
L{twisted.python.log} with the L{warnings} module.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class LogTest(unittest.SynchronousTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.catcher = []
|
||||
self.observer = self.catcher.append
|
||||
log.addObserver(self.observer)
|
||||
self.addCleanup(log.removeObserver, self.observer)
|
||||
|
||||
|
||||
def testObservation(self):
|
||||
catcher = self.catcher
|
||||
log.msg("test", testShouldCatch=True)
|
||||
i = catcher.pop()
|
||||
self.assertEqual(i["message"][0], "test")
|
||||
self.assertEqual(i["testShouldCatch"], True)
|
||||
self.assertIn("time", i)
|
||||
self.assertEqual(len(catcher), 0)
|
||||
|
||||
|
||||
def testContext(self):
|
||||
catcher = self.catcher
|
||||
log.callWithContext({"subsystem": "not the default",
|
||||
"subsubsystem": "a",
|
||||
"other": "c"},
|
||||
log.callWithContext,
|
||||
{"subsubsystem": "b"}, log.msg, "foo", other="d")
|
||||
i = catcher.pop()
|
||||
self.assertEqual(i['subsubsystem'], 'b')
|
||||
self.assertEqual(i['subsystem'], 'not the default')
|
||||
self.assertEqual(i['other'], 'd')
|
||||
self.assertEqual(i['message'][0], 'foo')
|
||||
|
||||
def testErrors(self):
|
||||
for e, ig in [("hello world","hello world"),
|
||||
(KeyError(), KeyError),
|
||||
(failure.Failure(RuntimeError()), RuntimeError)]:
|
||||
log.err(e)
|
||||
i = self.catcher.pop()
|
||||
self.assertEqual(i['isError'], 1)
|
||||
self.flushLoggedErrors(ig)
|
||||
|
||||
def testErrorsWithWhy(self):
|
||||
for e, ig in [("hello world","hello world"),
|
||||
(KeyError(), KeyError),
|
||||
(failure.Failure(RuntimeError()), RuntimeError)]:
|
||||
log.err(e, 'foobar')
|
||||
i = self.catcher.pop()
|
||||
self.assertEqual(i['isError'], 1)
|
||||
self.assertEqual(i['why'], 'foobar')
|
||||
self.flushLoggedErrors(ig)
|
||||
|
||||
|
||||
def test_erroneousErrors(self):
|
||||
"""
|
||||
Exceptions raised by log observers are logged but the observer which
|
||||
raised the exception remains registered with the publisher. These
|
||||
exceptions do not prevent the event from being sent to other observers
|
||||
registered with the publisher.
|
||||
"""
|
||||
L1 = []
|
||||
L2 = []
|
||||
def broken(events):
|
||||
1 // 0
|
||||
|
||||
for observer in [L1.append, broken, L2.append]:
|
||||
log.addObserver(observer)
|
||||
self.addCleanup(log.removeObserver, observer)
|
||||
|
||||
for i in range(3):
|
||||
# Reset the lists for simpler comparison.
|
||||
L1[:] = []
|
||||
L2[:] = []
|
||||
|
||||
# Send out the event which will break one of the observers.
|
||||
log.msg("Howdy, y'all.")
|
||||
|
||||
# The broken observer should have caused this to be logged.
|
||||
excs = self.flushLoggedErrors(ZeroDivisionError)
|
||||
del self.catcher[:]
|
||||
self.assertEqual(len(excs), 1)
|
||||
|
||||
# Both other observers should have seen the message.
|
||||
self.assertEqual(len(L1), 2)
|
||||
self.assertEqual(len(L2), 2)
|
||||
|
||||
# The order is slightly wrong here. The first event should be
|
||||
# delivered to all observers; then, errors should be delivered.
|
||||
self.assertEqual(L1[1]['message'], ("Howdy, y'all.",))
|
||||
self.assertEqual(L2[0]['message'], ("Howdy, y'all.",))
|
||||
|
||||
|
||||
def test_doubleErrorDoesNotRemoveObserver(self):
|
||||
"""
|
||||
If logging causes an error, make sure that if logging the fact that
|
||||
logging failed also causes an error, the log observer is not removed.
|
||||
"""
|
||||
events = []
|
||||
errors = []
|
||||
publisher = log.LogPublisher()
|
||||
|
||||
class FailingObserver(object):
|
||||
calls = 0
|
||||
def log(self, msg, **kwargs):
|
||||
# First call raises RuntimeError:
|
||||
self.calls += 1
|
||||
if self.calls < 2:
|
||||
raise RuntimeError("Failure #%s" % (self.calls,))
|
||||
else:
|
||||
events.append(msg)
|
||||
|
||||
observer = FailingObserver()
|
||||
publisher.addObserver(observer.log)
|
||||
self.assertEqual(publisher.observers, [observer.log])
|
||||
|
||||
try:
|
||||
# When observer throws, the publisher attempts to log the fact by
|
||||
# calling self._err()... which also fails with recursion error:
|
||||
oldError = publisher._err
|
||||
|
||||
def failingErr(failure, why, **kwargs):
|
||||
errors.append(failure.value)
|
||||
raise RuntimeError("Fake recursion error")
|
||||
|
||||
publisher._err = failingErr
|
||||
publisher.msg("error in first observer")
|
||||
finally:
|
||||
publisher._err = oldError
|
||||
# Observer should still exist; we do this in finally since before
|
||||
# bug was fixed the test would fail due to uncaught exception, so
|
||||
# we want failing assert too in that case:
|
||||
self.assertEqual(publisher.observers, [observer.log])
|
||||
|
||||
# The next message should succeed:
|
||||
publisher.msg("but this should succeed")
|
||||
|
||||
self.assertEqual(observer.calls, 2)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0]['message'], ("but this should succeed",))
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIsInstance(errors[0], RuntimeError)
|
||||
|
||||
|
||||
def test_showwarning(self):
|
||||
"""
|
||||
L{twisted.python.log.showwarning} emits the warning as a message
|
||||
to the Twisted logging system.
|
||||
"""
|
||||
publisher = log.LogPublisher()
|
||||
publisher.addObserver(self.observer)
|
||||
|
||||
publisher.showwarning(
|
||||
FakeWarning("unique warning message"), FakeWarning,
|
||||
"warning-filename.py", 27)
|
||||
event = self.catcher.pop()
|
||||
self.assertEqual(
|
||||
event['format'] % event,
|
||||
'warning-filename.py:27: twisted.test.test_log.FakeWarning: '
|
||||
'unique warning message')
|
||||
self.assertEqual(self.catcher, [])
|
||||
|
||||
# Python 2.6 requires that any function used to override the
|
||||
# warnings.showwarning API accept a "line" parameter or a
|
||||
# deprecation warning is emitted.
|
||||
publisher.showwarning(
|
||||
FakeWarning("unique warning message"), FakeWarning,
|
||||
"warning-filename.py", 27, line=object())
|
||||
event = self.catcher.pop()
|
||||
self.assertEqual(
|
||||
event['format'] % event,
|
||||
'warning-filename.py:27: twisted.test.test_log.FakeWarning: '
|
||||
'unique warning message')
|
||||
self.assertEqual(self.catcher, [])
|
||||
|
||||
|
||||
def test_warningToFile(self):
|
||||
"""
|
||||
L{twisted.python.log.showwarning} passes warnings with an explicit file
|
||||
target on to the underlying Python warning system.
|
||||
"""
|
||||
# log.showwarning depends on _oldshowwarning being set, which only
|
||||
# happens in startLogging(), which doesn't happen if you're not
|
||||
# running under trial. So this test only passes by accident of runner
|
||||
# environment.
|
||||
if log._oldshowwarning is None:
|
||||
raise unittest.SkipTest("Currently this test only runs under trial.")
|
||||
message = "another unique message"
|
||||
category = FakeWarning
|
||||
filename = "warning-filename.py"
|
||||
lineno = 31
|
||||
|
||||
output = StringIO()
|
||||
log.showwarning(message, category, filename, lineno, file=output)
|
||||
|
||||
self.assertEqual(
|
||||
output.getvalue(),
|
||||
warnings.formatwarning(message, category, filename, lineno))
|
||||
|
||||
# In Python 2.6, warnings.showwarning accepts a "line" argument which
|
||||
# gives the source line the warning message is to include.
|
||||
if sys.version_info >= (2, 6):
|
||||
line = "hello world"
|
||||
output = StringIO()
|
||||
log.showwarning(message, category, filename, lineno, file=output,
|
||||
line=line)
|
||||
|
||||
self.assertEqual(
|
||||
output.getvalue(),
|
||||
warnings.formatwarning(message, category, filename, lineno,
|
||||
line))
|
||||
|
||||
|
||||
def test_publisherReportsBrokenObserversPrivately(self):
|
||||
"""
|
||||
Log publisher does not use the global L{log.err} when reporting broken
|
||||
observers.
|
||||
"""
|
||||
errors = []
|
||||
def logError(eventDict):
|
||||
if eventDict.get("isError"):
|
||||
errors.append(eventDict["failure"].value)
|
||||
|
||||
def fail(eventDict):
|
||||
raise RuntimeError("test_publisherLocalyReportsBrokenObservers")
|
||||
|
||||
publisher = log.LogPublisher()
|
||||
publisher.addObserver(logError)
|
||||
publisher.addObserver(fail)
|
||||
|
||||
publisher.msg("Hello!")
|
||||
self.assertEqual(publisher.observers, [logError, fail])
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIsInstance(errors[0], RuntimeError)
|
||||
|
||||
|
||||
|
||||
class FakeFile(list):
|
||||
def write(self, bytes):
|
||||
self.append(bytes)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
class EvilStr:
|
||||
def __str__(self):
|
||||
1//0
|
||||
|
||||
class EvilRepr:
|
||||
def __str__(self):
|
||||
return "Happy Evil Repr"
|
||||
def __repr__(self):
|
||||
1//0
|
||||
|
||||
class EvilReprStr(EvilStr, EvilRepr):
|
||||
pass
|
||||
|
||||
class LogPublisherTestCaseMixin:
|
||||
def setUp(self):
|
||||
"""
|
||||
Add a log observer which records log events in C{self.out}. Also,
|
||||
make sure the default string encoding is ASCII so that
|
||||
L{testSingleUnicode} can test the behavior of logging unencodable
|
||||
unicode messages.
|
||||
"""
|
||||
self.out = FakeFile()
|
||||
self.lp = log.LogPublisher()
|
||||
self.flo = log.FileLogObserver(self.out)
|
||||
self.lp.addObserver(self.flo.emit)
|
||||
|
||||
try:
|
||||
str(u'\N{VULGAR FRACTION ONE HALF}')
|
||||
except UnicodeEncodeError:
|
||||
# This is the behavior we want - don't change anything.
|
||||
self._origEncoding = None
|
||||
else:
|
||||
if _PY3:
|
||||
self._origEncoding = None
|
||||
return
|
||||
reload(sys)
|
||||
self._origEncoding = sys.getdefaultencoding()
|
||||
sys.setdefaultencoding('ascii')
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Verify that everything written to the fake file C{self.out} was a
|
||||
C{str}. Also, restore the default string encoding to its previous
|
||||
setting, if it was modified by L{setUp}.
|
||||
"""
|
||||
for chunk in self.out:
|
||||
self.failUnless(isinstance(chunk, str), "%r was not a string" % (chunk,))
|
||||
|
||||
if self._origEncoding is not None:
|
||||
sys.setdefaultencoding(self._origEncoding)
|
||||
del sys.setdefaultencoding
|
||||
|
||||
|
||||
|
||||
class LogPublisherTestCase(LogPublisherTestCaseMixin, unittest.SynchronousTestCase):
|
||||
def testSingleString(self):
|
||||
self.lp.msg("Hello, world.")
|
||||
self.assertEqual(len(self.out), 1)
|
||||
|
||||
|
||||
def testMultipleString(self):
|
||||
# Test some stupid behavior that will be deprecated real soon.
|
||||
# If you are reading this and trying to learn how the logging
|
||||
# system works, *do not use this feature*.
|
||||
self.lp.msg("Hello, ", "world.")
|
||||
self.assertEqual(len(self.out), 1)
|
||||
|
||||
|
||||
def test_singleUnicode(self):
|
||||
"""
|
||||
L{log.LogPublisher.msg} does not accept non-ASCII Unicode on Python 2,
|
||||
logging an error instead.
|
||||
|
||||
On Python 3, where Unicode is default message type, the message is
|
||||
logged normally.
|
||||
"""
|
||||
message = u"Hello, \N{VULGAR FRACTION ONE HALF} world."
|
||||
self.lp.msg(message)
|
||||
self.assertEqual(len(self.out), 1)
|
||||
if _PY3:
|
||||
self.assertIn(message, self.out[0])
|
||||
else:
|
||||
self.assertIn('with str error', self.out[0])
|
||||
self.assertIn('UnicodeEncodeError', self.out[0])
|
||||
|
||||
|
||||
|
||||
class FileObserverTestCase(LogPublisherTestCaseMixin, unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{log.FileObserver}.
|
||||
"""
|
||||
def _getTimezoneOffsetTest(self, tzname, daylightOffset, standardOffset):
|
||||
"""
|
||||
Verify that L{getTimezoneOffset} produces the expected offset for a
|
||||
certain timezone both when daylight saving time is in effect and when
|
||||
it is not.
|
||||
|
||||
@param tzname: The name of a timezone to exercise.
|
||||
@type tzname: L{bytes}
|
||||
|
||||
@param daylightOffset: The number of seconds west of UTC the timezone
|
||||
should be when daylight saving time is in effect.
|
||||
@type daylightOffset: L{int}
|
||||
|
||||
@param standardOffset: The number of seconds west of UTC the timezone
|
||||
should be when daylight saving time is not in effect.
|
||||
@type standardOffset: L{int}
|
||||
"""
|
||||
if getattr(time, 'tzset', None) is None:
|
||||
raise unittest.SkipTest(
|
||||
"Platform cannot change timezone, cannot verify correct offsets "
|
||||
"in well-known timezones.")
|
||||
|
||||
originalTimezone = os.environ.get('TZ', None)
|
||||
try:
|
||||
os.environ['TZ'] = tzname
|
||||
time.tzset()
|
||||
|
||||
# The behavior of mktime depends on the current timezone setting.
|
||||
# So only do this after changing the timezone.
|
||||
|
||||
# Compute a POSIX timestamp for a certain date and time that is
|
||||
# known to occur at a time when daylight saving time is in effect.
|
||||
localDaylightTuple = (2006, 6, 30, 0, 0, 0, 4, 181, 1)
|
||||
daylight = time.mktime(localDaylightTuple)
|
||||
|
||||
# Compute a POSIX timestamp for a certain date and time that is
|
||||
# known to occur at a time when daylight saving time is not in
|
||||
# effect.
|
||||
localStandardTuple = (2007, 1, 31, 0, 0, 0, 2, 31, 0)
|
||||
standard = time.mktime(localStandardTuple)
|
||||
|
||||
self.assertEqual(
|
||||
(self.flo.getTimezoneOffset(daylight),
|
||||
self.flo.getTimezoneOffset(standard)),
|
||||
(daylightOffset, standardOffset))
|
||||
finally:
|
||||
if originalTimezone is None:
|
||||
del os.environ['TZ']
|
||||
else:
|
||||
os.environ['TZ'] = originalTimezone
|
||||
time.tzset()
|
||||
|
||||
|
||||
def test_getTimezoneOffsetWestOfUTC(self):
|
||||
"""
|
||||
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
|
||||
correct values for the current C{TZ} environment setting for at least
|
||||
some cases. This test method exercises a timezone that is west of UTC
|
||||
(and should produce positive results).
|
||||
"""
|
||||
self._getTimezoneOffsetTest("America/New_York", 14400, 18000)
|
||||
|
||||
|
||||
def test_getTimezoneOffsetEastOfUTC(self):
|
||||
"""
|
||||
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
|
||||
correct values for the current C{TZ} environment setting for at least
|
||||
some cases. This test method exercises a timezone that is east of UTC
|
||||
(and should produce negative results).
|
||||
"""
|
||||
self._getTimezoneOffsetTest("Europe/Berlin", -7200, -3600)
|
||||
|
||||
|
||||
def test_getTimezoneOffsetWithoutDaylightSavingTime(self):
|
||||
"""
|
||||
Attempt to verify that L{FileLogObserver.getTimezoneOffset} returns
|
||||
correct values for the current C{TZ} environment setting for at least
|
||||
some cases. This test method exercises a timezone that does not use
|
||||
daylight saving time at all (so both summer and winter time test values
|
||||
should have the same offset).
|
||||
"""
|
||||
# Test a timezone that doesn't have DST. mktime() implementations
|
||||
# available for testing seem happy to produce results for this even
|
||||
# though it's not entirely valid.
|
||||
self._getTimezoneOffsetTest("Africa/Johannesburg", -7200, -7200)
|
||||
|
||||
|
||||
def test_timeFormatting(self):
|
||||
"""
|
||||
Test the method of L{FileLogObserver} which turns a timestamp into a
|
||||
human-readable string.
|
||||
"""
|
||||
when = calendar.timegm((2001, 2, 3, 4, 5, 6, 7, 8, 0))
|
||||
|
||||
# Pretend to be in US/Eastern for a moment
|
||||
self.flo.getTimezoneOffset = lambda when: 18000
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-02 23:05:06-0500')
|
||||
|
||||
# Okay now we're in Eastern Europe somewhere
|
||||
self.flo.getTimezoneOffset = lambda when: -3600
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 05:05:06+0100')
|
||||
|
||||
# And off in the Pacific or someplace like that
|
||||
self.flo.getTimezoneOffset = lambda when: -39600
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 15:05:06+1100')
|
||||
|
||||
# One of those weird places with a half-hour offset timezone
|
||||
self.flo.getTimezoneOffset = lambda when: 5400
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 02:35:06-0130')
|
||||
|
||||
# Half-hour offset in the other direction
|
||||
self.flo.getTimezoneOffset = lambda when: -5400
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 05:35:06+0130')
|
||||
|
||||
# Test an offset which is between 0 and 60 minutes to make sure the
|
||||
# sign comes out properly in that case.
|
||||
self.flo.getTimezoneOffset = lambda when: 1800
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 03:35:06-0030')
|
||||
|
||||
# Test an offset between 0 and 60 minutes in the other direction.
|
||||
self.flo.getTimezoneOffset = lambda when: -1800
|
||||
self.assertEqual(self.flo.formatTime(when), '2001-02-03 04:35:06+0030')
|
||||
|
||||
# If a strftime-format string is present on the logger, it should
|
||||
# use that instead. Note we don't assert anything about day, hour
|
||||
# or minute because we cannot easily control what time.strftime()
|
||||
# thinks the local timezone is.
|
||||
self.flo.timeFormat = '%Y %m'
|
||||
self.assertEqual(self.flo.formatTime(when), '2001 02')
|
||||
|
||||
|
||||
def test_microsecondTimestampFormatting(self):
|
||||
"""
|
||||
L{FileLogObserver.formatTime} supports a value of C{timeFormat} which
|
||||
includes C{"%f"}, a L{datetime}-only format specifier for microseconds.
|
||||
"""
|
||||
self.flo.timeFormat = '%f'
|
||||
self.assertEqual("600000", self.flo.formatTime(12345.6))
|
||||
|
||||
|
||||
def test_loggingAnObjectWithBroken__str__(self):
|
||||
#HELLO, MCFLY
|
||||
self.lp.msg(EvilStr())
|
||||
self.assertEqual(len(self.out), 1)
|
||||
# Logging system shouldn't need to crap itself for this trivial case
|
||||
self.assertNotIn('UNFORMATTABLE', self.out[0])
|
||||
|
||||
|
||||
def test_formattingAnObjectWithBroken__str__(self):
|
||||
self.lp.msg(format='%(blat)s', blat=EvilStr())
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('Invalid format string or unformattable object', self.out[0])
|
||||
|
||||
|
||||
def test_brokenSystem__str__(self):
|
||||
self.lp.msg('huh', system=EvilStr())
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('Invalid format string or unformattable object', self.out[0])
|
||||
|
||||
|
||||
def test_formattingAnObjectWithBroken__repr__Indirect(self):
|
||||
self.lp.msg(format='%(blat)s', blat=[EvilRepr()])
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('UNFORMATTABLE OBJECT', self.out[0])
|
||||
|
||||
|
||||
def test_systemWithBroker__repr__Indirect(self):
|
||||
self.lp.msg('huh', system=[EvilRepr()])
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('UNFORMATTABLE OBJECT', self.out[0])
|
||||
|
||||
|
||||
def test_simpleBrokenFormat(self):
|
||||
self.lp.msg(format='hooj %s %s', blat=1)
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('Invalid format string or unformattable object', self.out[0])
|
||||
|
||||
|
||||
def test_ridiculousFormat(self):
|
||||
self.lp.msg(format=42, blat=1)
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('Invalid format string or unformattable object', self.out[0])
|
||||
|
||||
|
||||
def test_evilFormat__repr__And__str__(self):
|
||||
self.lp.msg(format=EvilReprStr(), blat=1)
|
||||
self.assertEqual(len(self.out), 1)
|
||||
self.assertIn('PATHOLOGICAL', self.out[0])
|
||||
|
||||
|
||||
def test_strangeEventDict(self):
|
||||
"""
|
||||
This kind of eventDict used to fail silently, so test it does.
|
||||
"""
|
||||
self.lp.msg(message='', isError=False)
|
||||
self.assertEqual(len(self.out), 0)
|
||||
|
||||
|
||||
def _startLoggingCleanup(self):
|
||||
"""
|
||||
Cleanup after a startLogging() call that mutates the hell out of some
|
||||
global state.
|
||||
"""
|
||||
origShowwarnings = log._oldshowwarning
|
||||
self.addCleanup(setattr, log, "_oldshowwarning", origShowwarnings)
|
||||
self.addCleanup(setattr, sys, 'stdout', sys.stdout)
|
||||
self.addCleanup(setattr, sys, 'stderr', sys.stderr)
|
||||
|
||||
def test_startLogging(self):
|
||||
"""
|
||||
startLogging() installs FileLogObserver and overrides sys.stdout and
|
||||
sys.stderr.
|
||||
"""
|
||||
origStdout, origStderr = sys.stdout, sys.stderr
|
||||
self._startLoggingCleanup()
|
||||
# When done with test, reset stdout and stderr to current values:
|
||||
fakeFile = StringIO()
|
||||
observer = log.startLogging(fakeFile)
|
||||
self.addCleanup(observer.stop)
|
||||
log.msg("Hello!")
|
||||
self.assertIn("Hello!", fakeFile.getvalue())
|
||||
self.assertIsInstance(sys.stdout, log.StdioOnnaStick)
|
||||
self.assertEqual(sys.stdout.isError, False)
|
||||
encoding = getattr(origStdout, "encoding", None)
|
||||
if not encoding:
|
||||
encoding = sys.getdefaultencoding()
|
||||
self.assertEqual(sys.stdout.encoding, encoding)
|
||||
self.assertIsInstance(sys.stderr, log.StdioOnnaStick)
|
||||
self.assertEqual(sys.stderr.isError, True)
|
||||
encoding = getattr(origStderr, "encoding", None)
|
||||
if not encoding:
|
||||
encoding = sys.getdefaultencoding()
|
||||
self.assertEqual(sys.stderr.encoding, encoding)
|
||||
|
||||
|
||||
def test_startLoggingTwice(self):
|
||||
"""
|
||||
There are some obscure error conditions that can occur when logging is
|
||||
started twice. See http://twistedmatrix.com/trac/ticket/3289 for more
|
||||
information.
|
||||
"""
|
||||
self._startLoggingCleanup()
|
||||
# The bug is particular to the way that the t.p.log 'global' function
|
||||
# handle stdout. If we use our own stream, the error doesn't occur. If
|
||||
# we use our own LogPublisher, the error doesn't occur.
|
||||
sys.stdout = StringIO()
|
||||
|
||||
def showError(eventDict):
|
||||
if eventDict['isError']:
|
||||
sys.__stdout__.write(eventDict['failure'].getTraceback())
|
||||
|
||||
log.addObserver(showError)
|
||||
self.addCleanup(log.removeObserver, showError)
|
||||
observer = log.startLogging(sys.stdout)
|
||||
self.addCleanup(observer.stop)
|
||||
# At this point, we expect that sys.stdout is a StdioOnnaStick object.
|
||||
self.assertIsInstance(sys.stdout, log.StdioOnnaStick)
|
||||
fakeStdout = sys.stdout
|
||||
observer = log.startLogging(sys.stdout)
|
||||
self.assertIdentical(sys.stdout, fakeStdout)
|
||||
|
||||
|
||||
def test_startLoggingOverridesWarning(self):
|
||||
"""
|
||||
startLogging() overrides global C{warnings.showwarning} such that
|
||||
warnings go to Twisted log observers.
|
||||
"""
|
||||
self._startLoggingCleanup()
|
||||
# Ugggh, pretend we're starting from newly imported module:
|
||||
log._oldshowwarning = None
|
||||
fakeFile = StringIO()
|
||||
observer = log.startLogging(fakeFile)
|
||||
self.addCleanup(observer.stop)
|
||||
warnings.warn("hello!")
|
||||
output = fakeFile.getvalue()
|
||||
self.assertIn("UserWarning: hello!", output)
|
||||
|
||||
|
||||
|
||||
class PythonLoggingObserverTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Test the bridge with python logging module.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.out = StringIO()
|
||||
|
||||
rootLogger = logging.getLogger("")
|
||||
self.originalLevel = rootLogger.getEffectiveLevel()
|
||||
rootLogger.setLevel(logging.DEBUG)
|
||||
self.hdlr = logging.StreamHandler(self.out)
|
||||
fmt = logging.Formatter(logging.BASIC_FORMAT)
|
||||
self.hdlr.setFormatter(fmt)
|
||||
rootLogger.addHandler(self.hdlr)
|
||||
|
||||
self.lp = log.LogPublisher()
|
||||
self.obs = log.PythonLoggingObserver()
|
||||
self.lp.addObserver(self.obs.emit)
|
||||
|
||||
def tearDown(self):
|
||||
rootLogger = logging.getLogger("")
|
||||
rootLogger.removeHandler(self.hdlr)
|
||||
rootLogger.setLevel(self.originalLevel)
|
||||
logging.shutdown()
|
||||
|
||||
def test_singleString(self):
|
||||
"""
|
||||
Test simple output, and default log level.
|
||||
"""
|
||||
self.lp.msg("Hello, world.")
|
||||
self.assertIn("Hello, world.", self.out.getvalue())
|
||||
self.assertIn("INFO", self.out.getvalue())
|
||||
|
||||
def test_errorString(self):
|
||||
"""
|
||||
Test error output.
|
||||
"""
|
||||
self.lp.msg(failure=failure.Failure(ValueError("That is bad.")), isError=True)
|
||||
self.assertIn("ERROR", self.out.getvalue())
|
||||
|
||||
def test_formatString(self):
|
||||
"""
|
||||
Test logging with a format.
|
||||
"""
|
||||
self.lp.msg(format="%(bar)s oo %(foo)s", bar="Hello", foo="world")
|
||||
self.assertIn("Hello oo world", self.out.getvalue())
|
||||
|
||||
def test_customLevel(self):
|
||||
"""
|
||||
Test the logLevel keyword for customizing level used.
|
||||
"""
|
||||
self.lp.msg("Spam egg.", logLevel=logging.DEBUG)
|
||||
self.assertIn("Spam egg.", self.out.getvalue())
|
||||
self.assertIn("DEBUG", self.out.getvalue())
|
||||
self.out.seek(0, 0)
|
||||
self.out.truncate()
|
||||
self.lp.msg("Foo bar.", logLevel=logging.WARNING)
|
||||
self.assertIn("Foo bar.", self.out.getvalue())
|
||||
self.assertIn("WARNING", self.out.getvalue())
|
||||
|
||||
def test_strangeEventDict(self):
|
||||
"""
|
||||
Verify that an event dictionary which is not an error and has an empty
|
||||
message isn't recorded.
|
||||
"""
|
||||
self.lp.msg(message='', isError=False)
|
||||
self.assertEqual(self.out.getvalue(), '')
|
||||
|
||||
|
||||
class PythonLoggingIntegrationTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Test integration of python logging bridge.
|
||||
"""
|
||||
def test_startStopObserver(self):
|
||||
"""
|
||||
Test that start and stop methods of the observer actually register
|
||||
and unregister to the log system.
|
||||
"""
|
||||
oldAddObserver = log.addObserver
|
||||
oldRemoveObserver = log.removeObserver
|
||||
l = []
|
||||
try:
|
||||
log.addObserver = l.append
|
||||
log.removeObserver = l.remove
|
||||
obs = log.PythonLoggingObserver()
|
||||
obs.start()
|
||||
self.assertEqual(l[0], obs.emit)
|
||||
obs.stop()
|
||||
self.assertEqual(len(l), 0)
|
||||
finally:
|
||||
log.addObserver = oldAddObserver
|
||||
log.removeObserver = oldRemoveObserver
|
||||
|
||||
def test_inheritance(self):
|
||||
"""
|
||||
Test that we can inherit L{log.PythonLoggingObserver} and use super:
|
||||
that's basically a validation that L{log.PythonLoggingObserver} is
|
||||
new-style class.
|
||||
"""
|
||||
class MyObserver(log.PythonLoggingObserver):
|
||||
def emit(self, eventDict):
|
||||
super(MyObserver, self).emit(eventDict)
|
||||
obs = MyObserver()
|
||||
l = []
|
||||
oldEmit = log.PythonLoggingObserver.emit
|
||||
try:
|
||||
log.PythonLoggingObserver.emit = l.append
|
||||
obs.emit('foo')
|
||||
self.assertEqual(len(l), 1)
|
||||
finally:
|
||||
log.PythonLoggingObserver.emit = oldEmit
|
||||
|
||||
|
||||
|
||||
class DefaultObserverTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Test the default observer.
|
||||
"""
|
||||
|
||||
def test_failureLogger(self):
|
||||
"""
|
||||
The reason argument passed to log.err() appears in the report
|
||||
generated by DefaultObserver.
|
||||
"""
|
||||
self.catcher = []
|
||||
self.observer = self.catcher.append
|
||||
log.addObserver(self.observer)
|
||||
self.addCleanup(log.removeObserver, self.observer)
|
||||
|
||||
obs = log.DefaultObserver()
|
||||
obs.stderr = StringIO()
|
||||
obs.start()
|
||||
self.addCleanup(obs.stop)
|
||||
|
||||
reason = "The reason."
|
||||
log.err(Exception(), reason)
|
||||
errors = self.flushLoggedErrors()
|
||||
|
||||
self.assertIn(reason, obs.stderr.getvalue())
|
||||
self.assertEqual(len(errors), 1)
|
||||
|
||||
|
||||
|
||||
class StdioOnnaStickTestCase(unittest.SynchronousTestCase):
|
||||
"""
|
||||
StdioOnnaStick should act like the normal sys.stdout object.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.resultLogs = []
|
||||
log.addObserver(self.resultLogs.append)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
log.removeObserver(self.resultLogs.append)
|
||||
|
||||
|
||||
def getLogMessages(self):
|
||||
return ["".join(d['message']) for d in self.resultLogs]
|
||||
|
||||
|
||||
def test_write(self):
|
||||
"""
|
||||
Writing to a StdioOnnaStick instance results in Twisted log messages.
|
||||
|
||||
Log messages are generated every time a '\n' is encountered.
|
||||
"""
|
||||
stdio = log.StdioOnnaStick()
|
||||
stdio.write("Hello there\nThis is a test")
|
||||
self.assertEqual(self.getLogMessages(), ["Hello there"])
|
||||
stdio.write("!\n")
|
||||
self.assertEqual(self.getLogMessages(), ["Hello there", "This is a test!"])
|
||||
|
||||
|
||||
def test_metadata(self):
|
||||
"""
|
||||
The log messages written by StdioOnnaStick have printed=1 keyword, and
|
||||
by default are not errors.
|
||||
"""
|
||||
stdio = log.StdioOnnaStick()
|
||||
stdio.write("hello\n")
|
||||
self.assertEqual(self.resultLogs[0]['isError'], False)
|
||||
self.assertEqual(self.resultLogs[0]['printed'], True)
|
||||
|
||||
|
||||
def test_writeLines(self):
|
||||
"""
|
||||
Writing lines to a StdioOnnaStick results in Twisted log messages.
|
||||
"""
|
||||
stdio = log.StdioOnnaStick()
|
||||
stdio.writelines(["log 1", "log 2"])
|
||||
self.assertEqual(self.getLogMessages(), ["log 1", "log 2"])
|
||||
|
||||
|
||||
def test_print(self):
|
||||
"""
|
||||
When StdioOnnaStick is set as sys.stdout, prints become log messages.
|
||||
"""
|
||||
oldStdout = sys.stdout
|
||||
sys.stdout = log.StdioOnnaStick()
|
||||
self.addCleanup(setattr, sys, "stdout", oldStdout)
|
||||
print("This", end=" ")
|
||||
print("is a test")
|
||||
self.assertEqual(self.getLogMessages(), ["This is a test"])
|
||||
|
||||
|
||||
def test_error(self):
|
||||
"""
|
||||
StdioOnnaStick created with isError=True log messages as errors.
|
||||
"""
|
||||
stdio = log.StdioOnnaStick(isError=True)
|
||||
stdio.write("log 1\n")
|
||||
self.assertEqual(self.resultLogs[0]['isError'], True)
|
||||
|
||||
|
||||
def test_unicode(self):
|
||||
"""
|
||||
StdioOnnaStick converts unicode prints to byte strings on Python 2, in
|
||||
order to be compatible with the normal stdout/stderr objects.
|
||||
|
||||
On Python 3, the prints are left unmodified.
|
||||
"""
|
||||
unicodeString = u"Hello, \N{VULGAR FRACTION ONE HALF} world."
|
||||
stdio = log.StdioOnnaStick(encoding="utf-8")
|
||||
self.assertEqual(stdio.encoding, "utf-8")
|
||||
stdio.write(unicodeString + u"\n")
|
||||
stdio.writelines([u"Also, " + unicodeString])
|
||||
oldStdout = sys.stdout
|
||||
sys.stdout = stdio
|
||||
self.addCleanup(setattr, sys, "stdout", oldStdout)
|
||||
# This should go to the log, utf-8 encoded too:
|
||||
print(unicodeString)
|
||||
if _PY3:
|
||||
self.assertEqual(self.getLogMessages(),
|
||||
[unicodeString,
|
||||
u"Also, " + unicodeString,
|
||||
unicodeString])
|
||||
else:
|
||||
self.assertEqual(self.getLogMessages(),
|
||||
[unicodeString.encode("utf-8"),
|
||||
(u"Also, " + unicodeString).encode("utf-8"),
|
||||
unicodeString.encode("utf-8")])
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import os, time, stat, errno
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import logfile, runtime
|
||||
|
||||
|
||||
class LogFileTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test the rotating log file.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.dir = self.mktemp()
|
||||
os.makedirs(self.dir)
|
||||
self.name = "test.log"
|
||||
self.path = os.path.join(self.dir, self.name)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore back write rights on created paths: if tests modified the
|
||||
rights, that will allow the paths to be removed easily afterwards.
|
||||
"""
|
||||
os.chmod(self.dir, 0777)
|
||||
if os.path.exists(self.path):
|
||||
os.chmod(self.path, 0777)
|
||||
|
||||
|
||||
def testWriting(self):
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
log.write("123")
|
||||
log.write("456")
|
||||
log.flush()
|
||||
log.write("7890")
|
||||
log.close()
|
||||
|
||||
f = open(self.path, "r")
|
||||
self.assertEqual(f.read(), "1234567890")
|
||||
f.close()
|
||||
|
||||
def testRotation(self):
|
||||
# this logfile should rotate every 10 bytes
|
||||
log = logfile.LogFile(self.name, self.dir, rotateLength=10)
|
||||
|
||||
# test automatic rotation
|
||||
log.write("123")
|
||||
log.write("4567890")
|
||||
log.write("1" * 11)
|
||||
self.assert_(os.path.exists("%s.1" % self.path))
|
||||
self.assert_(not os.path.exists("%s.2" % self.path))
|
||||
log.write('')
|
||||
self.assert_(os.path.exists("%s.1" % self.path))
|
||||
self.assert_(os.path.exists("%s.2" % self.path))
|
||||
self.assert_(not os.path.exists("%s.3" % self.path))
|
||||
log.write("3")
|
||||
self.assert_(not os.path.exists("%s.3" % self.path))
|
||||
|
||||
# test manual rotation
|
||||
log.rotate()
|
||||
self.assert_(os.path.exists("%s.3" % self.path))
|
||||
self.assert_(not os.path.exists("%s.4" % self.path))
|
||||
log.close()
|
||||
|
||||
self.assertEqual(log.listLogs(), [1, 2, 3])
|
||||
|
||||
def testAppend(self):
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
log.write("0123456789")
|
||||
log.close()
|
||||
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.assertEqual(log.size, 10)
|
||||
self.assertEqual(log._file.tell(), log.size)
|
||||
log.write("abc")
|
||||
self.assertEqual(log.size, 13)
|
||||
self.assertEqual(log._file.tell(), log.size)
|
||||
f = log._file
|
||||
f.seek(0, 0)
|
||||
self.assertEqual(f.read(), "0123456789abc")
|
||||
log.close()
|
||||
|
||||
def testLogReader(self):
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
log.write("abc\n")
|
||||
log.write("def\n")
|
||||
log.rotate()
|
||||
log.write("ghi\n")
|
||||
log.flush()
|
||||
|
||||
# check reading logs
|
||||
self.assertEqual(log.listLogs(), [1])
|
||||
reader = log.getCurrentLog()
|
||||
reader._file.seek(0)
|
||||
self.assertEqual(reader.readLines(), ["ghi\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
reader.close()
|
||||
reader = log.getLog(1)
|
||||
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
reader.close()
|
||||
|
||||
# check getting illegal log readers
|
||||
self.assertRaises(ValueError, log.getLog, 2)
|
||||
self.assertRaises(TypeError, log.getLog, "1")
|
||||
|
||||
# check that log numbers are higher for older logs
|
||||
log.rotate()
|
||||
self.assertEqual(log.listLogs(), [1, 2])
|
||||
reader = log.getLog(1)
|
||||
reader._file.seek(0)
|
||||
self.assertEqual(reader.readLines(), ["ghi\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
reader.close()
|
||||
reader = log.getLog(2)
|
||||
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
reader.close()
|
||||
|
||||
def testModePreservation(self):
|
||||
"""
|
||||
Check rotated files have same permissions as original.
|
||||
"""
|
||||
f = open(self.path, "w").close()
|
||||
os.chmod(self.path, 0707)
|
||||
mode = os.stat(self.path)[stat.ST_MODE]
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
log.write("abc")
|
||||
log.rotate()
|
||||
self.assertEqual(mode, os.stat(self.path)[stat.ST_MODE])
|
||||
|
||||
|
||||
def test_noPermission(self):
|
||||
"""
|
||||
Check it keeps working when permission on dir changes.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
log.write("abc")
|
||||
|
||||
# change permissions so rotation would fail
|
||||
os.chmod(self.dir, 0555)
|
||||
|
||||
# if this succeeds, chmod doesn't restrict us, so we can't
|
||||
# do the test
|
||||
try:
|
||||
f = open(os.path.join(self.dir,"xxx"), "w")
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
else:
|
||||
f.close()
|
||||
return
|
||||
|
||||
log.rotate() # this should not fail
|
||||
|
||||
log.write("def")
|
||||
log.flush()
|
||||
|
||||
f = log._file
|
||||
self.assertEqual(f.tell(), 6)
|
||||
f.seek(0, 0)
|
||||
self.assertEqual(f.read(), "abcdef")
|
||||
log.close()
|
||||
|
||||
|
||||
def test_maxNumberOfLog(self):
|
||||
"""
|
||||
Test it respect the limit on the number of files when maxRotatedFiles
|
||||
is not None.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir, rotateLength=10,
|
||||
maxRotatedFiles=3)
|
||||
log.write("1" * 11)
|
||||
log.write("2" * 11)
|
||||
self.failUnless(os.path.exists("%s.1" % self.path))
|
||||
|
||||
log.write("3" * 11)
|
||||
self.failUnless(os.path.exists("%s.2" % self.path))
|
||||
|
||||
log.write("4" * 11)
|
||||
self.failUnless(os.path.exists("%s.3" % self.path))
|
||||
self.assertEqual(file("%s.3" % self.path).read(), "1" * 11)
|
||||
|
||||
log.write("5" * 11)
|
||||
self.assertEqual(file("%s.3" % self.path).read(), "2" * 11)
|
||||
self.failUnless(not os.path.exists("%s.4" % self.path))
|
||||
|
||||
def test_fromFullPath(self):
|
||||
"""
|
||||
Test the fromFullPath method.
|
||||
"""
|
||||
log1 = logfile.LogFile(self.name, self.dir, 10, defaultMode=0777)
|
||||
log2 = logfile.LogFile.fromFullPath(self.path, 10, defaultMode=0777)
|
||||
self.assertEqual(log1.name, log2.name)
|
||||
self.assertEqual(os.path.abspath(log1.path), log2.path)
|
||||
self.assertEqual(log1.rotateLength, log2.rotateLength)
|
||||
self.assertEqual(log1.defaultMode, log2.defaultMode)
|
||||
|
||||
def test_defaultPermissions(self):
|
||||
"""
|
||||
Test the default permission of the log file: if the file exist, it
|
||||
should keep the permission.
|
||||
"""
|
||||
f = file(self.path, "w")
|
||||
os.chmod(self.path, 0707)
|
||||
currentMode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
|
||||
f.close()
|
||||
log1 = logfile.LogFile(self.name, self.dir)
|
||||
self.assertEqual(stat.S_IMODE(os.stat(self.path)[stat.ST_MODE]),
|
||||
currentMode)
|
||||
|
||||
|
||||
def test_specifiedPermissions(self):
|
||||
"""
|
||||
Test specifying the permissions used on the log file.
|
||||
"""
|
||||
log1 = logfile.LogFile(self.name, self.dir, defaultMode=0066)
|
||||
mode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
|
||||
if runtime.platform.isWindows():
|
||||
# The only thing we can get here is global read-only
|
||||
self.assertEqual(mode, 0444)
|
||||
else:
|
||||
self.assertEqual(mode, 0066)
|
||||
|
||||
|
||||
def test_reopen(self):
|
||||
"""
|
||||
L{logfile.LogFile.reopen} allows to rename the currently used file and
|
||||
make L{logfile.LogFile} create a new file.
|
||||
"""
|
||||
log1 = logfile.LogFile(self.name, self.dir)
|
||||
log1.write("hello1")
|
||||
savePath = os.path.join(self.dir, "save.log")
|
||||
os.rename(self.path, savePath)
|
||||
log1.reopen()
|
||||
log1.write("hello2")
|
||||
log1.close()
|
||||
|
||||
f = open(self.path, "r")
|
||||
self.assertEqual(f.read(), "hello2")
|
||||
f.close()
|
||||
f = open(savePath, "r")
|
||||
self.assertEqual(f.read(), "hello1")
|
||||
f.close()
|
||||
|
||||
if runtime.platform.isWindows():
|
||||
test_reopen.skip = "Can't test reopen on Windows"
|
||||
|
||||
|
||||
def test_nonExistentDir(self):
|
||||
"""
|
||||
Specifying an invalid directory to L{LogFile} raises C{IOError}.
|
||||
"""
|
||||
e = self.assertRaises(
|
||||
IOError, logfile.LogFile, self.name, 'this_dir_does_not_exist')
|
||||
self.assertEqual(e.errno, errno.ENOENT)
|
||||
|
||||
|
||||
|
||||
class RiggedDailyLogFile(logfile.DailyLogFile):
|
||||
_clock = 0.0
|
||||
|
||||
def _openFile(self):
|
||||
logfile.DailyLogFile._openFile(self)
|
||||
# rig the date to match _clock, not mtime
|
||||
self.lastDate = self.toDate()
|
||||
|
||||
def toDate(self, *args):
|
||||
if args:
|
||||
return time.gmtime(*args)[:3]
|
||||
return time.gmtime(self._clock)[:3]
|
||||
|
||||
class DailyLogFileTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test rotating log file.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.dir = self.mktemp()
|
||||
os.makedirs(self.dir)
|
||||
self.name = "testdaily.log"
|
||||
self.path = os.path.join(self.dir, self.name)
|
||||
|
||||
|
||||
def testWriting(self):
|
||||
log = RiggedDailyLogFile(self.name, self.dir)
|
||||
log.write("123")
|
||||
log.write("456")
|
||||
log.flush()
|
||||
log.write("7890")
|
||||
log.close()
|
||||
|
||||
f = open(self.path, "r")
|
||||
self.assertEqual(f.read(), "1234567890")
|
||||
f.close()
|
||||
|
||||
def testRotation(self):
|
||||
# this logfile should rotate every 10 bytes
|
||||
log = RiggedDailyLogFile(self.name, self.dir)
|
||||
days = [(self.path + '.' + log.suffix(day * 86400)) for day in range(3)]
|
||||
|
||||
# test automatic rotation
|
||||
log._clock = 0.0 # 1970/01/01 00:00.00
|
||||
log.write("123")
|
||||
log._clock = 43200 # 1970/01/01 12:00.00
|
||||
log.write("4567890")
|
||||
log._clock = 86400 # 1970/01/02 00:00.00
|
||||
log.write("1" * 11)
|
||||
self.assert_(os.path.exists(days[0]))
|
||||
self.assert_(not os.path.exists(days[1]))
|
||||
log._clock = 172800 # 1970/01/03 00:00.00
|
||||
log.write('')
|
||||
self.assert_(os.path.exists(days[0]))
|
||||
self.assert_(os.path.exists(days[1]))
|
||||
self.assert_(not os.path.exists(days[2]))
|
||||
log._clock = 259199 # 1970/01/03 23:59.59
|
||||
log.write("3")
|
||||
self.assert_(not os.path.exists(days[2]))
|
||||
|
||||
|
|
@ -0,0 +1,431 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test case for L{twisted.protocols.loopback}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.python.compat import _PY3, intToBytes
|
||||
from twisted.trial import unittest
|
||||
from twisted.trial.util import suppress as SUPPRESS
|
||||
from twisted.protocols import basic, loopback
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
|
||||
from twisted.internet import reactor, interfaces
|
||||
|
||||
|
||||
class SimpleProtocol(basic.LineReceiver):
|
||||
def __init__(self):
|
||||
self.conn = defer.Deferred()
|
||||
self.lines = []
|
||||
self.connLost = []
|
||||
|
||||
def connectionMade(self):
|
||||
self.conn.callback(None)
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.lines.append(line)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.connLost.append(reason)
|
||||
|
||||
|
||||
class DoomProtocol(SimpleProtocol):
|
||||
i = 0
|
||||
def lineReceived(self, line):
|
||||
self.i += 1
|
||||
if self.i < 4:
|
||||
# by this point we should have connection closed,
|
||||
# but just in case we didn't we won't ever send 'Hello 4'
|
||||
self.sendLine(b"Hello " + intToBytes(self.i))
|
||||
SimpleProtocol.lineReceived(self, line)
|
||||
if self.lines[-1] == b"Hello 3":
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
class LoopbackTestCaseMixin:
|
||||
def testRegularFunction(self):
|
||||
s = SimpleProtocol()
|
||||
c = SimpleProtocol()
|
||||
|
||||
def sendALine(result):
|
||||
s.sendLine(b"THIS IS LINE ONE!")
|
||||
s.transport.loseConnection()
|
||||
s.conn.addCallback(sendALine)
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(c.lines, [b"THIS IS LINE ONE!"])
|
||||
self.assertEqual(len(s.connLost), 1)
|
||||
self.assertEqual(len(c.connLost), 1)
|
||||
d = defer.maybeDeferred(self.loopbackFunc, s, c)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def testSneakyHiddenDoom(self):
|
||||
s = DoomProtocol()
|
||||
c = DoomProtocol()
|
||||
|
||||
def sendALine(result):
|
||||
s.sendLine(b"DOOM LINE")
|
||||
s.conn.addCallback(sendALine)
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(s.lines, [b'Hello 1', b'Hello 2', b'Hello 3'])
|
||||
self.assertEqual(
|
||||
c.lines, [b'DOOM LINE', b'Hello 1', b'Hello 2', b'Hello 3'])
|
||||
self.assertEqual(len(s.connLost), 1)
|
||||
self.assertEqual(len(c.connLost), 1)
|
||||
d = defer.maybeDeferred(self.loopbackFunc, s, c)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackAsync)
|
||||
|
||||
|
||||
def test_makeConnection(self):
|
||||
"""
|
||||
Test that the client and server protocol both have makeConnection
|
||||
invoked on them by loopbackAsync.
|
||||
"""
|
||||
class TestProtocol(Protocol):
|
||||
transport = None
|
||||
def makeConnection(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
server = TestProtocol()
|
||||
client = TestProtocol()
|
||||
loopback.loopbackAsync(server, client)
|
||||
self.failIfEqual(client.transport, None)
|
||||
self.failIfEqual(server.transport, None)
|
||||
|
||||
|
||||
def _hostpeertest(self, get, testServer):
|
||||
"""
|
||||
Test one of the permutations of client/server host/peer.
|
||||
"""
|
||||
class TestProtocol(Protocol):
|
||||
def makeConnection(self, transport):
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.onConnection.callback(transport)
|
||||
|
||||
if testServer:
|
||||
server = TestProtocol()
|
||||
d = server.onConnection = Deferred()
|
||||
client = Protocol()
|
||||
else:
|
||||
server = Protocol()
|
||||
client = TestProtocol()
|
||||
d = client.onConnection = Deferred()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
|
||||
def connected(transport):
|
||||
host = getattr(transport, get)()
|
||||
self.failUnless(IAddress.providedBy(host))
|
||||
|
||||
return d.addCallback(connected)
|
||||
|
||||
|
||||
def test_serverHost(self):
|
||||
"""
|
||||
Test that the server gets a transport with a properly functioning
|
||||
implementation of L{ITransport.getHost}.
|
||||
"""
|
||||
return self._hostpeertest("getHost", True)
|
||||
|
||||
|
||||
def test_serverPeer(self):
|
||||
"""
|
||||
Like C{test_serverHost} but for L{ITransport.getPeer}
|
||||
"""
|
||||
return self._hostpeertest("getPeer", True)
|
||||
|
||||
|
||||
def test_clientHost(self, get="getHost"):
|
||||
"""
|
||||
Test that the client gets a transport with a properly functioning
|
||||
implementation of L{ITransport.getHost}.
|
||||
"""
|
||||
return self._hostpeertest("getHost", False)
|
||||
|
||||
|
||||
def test_clientPeer(self):
|
||||
"""
|
||||
Like C{test_clientHost} but for L{ITransport.getPeer}.
|
||||
"""
|
||||
return self._hostpeertest("getPeer", False)
|
||||
|
||||
|
||||
def _greetingtest(self, write, testServer):
|
||||
"""
|
||||
Test one of the permutations of write/writeSequence client/server.
|
||||
|
||||
@param write: The name of the method to test, C{"write"} or
|
||||
C{"writeSequence"}.
|
||||
"""
|
||||
class GreeteeProtocol(Protocol):
|
||||
bytes = b""
|
||||
def dataReceived(self, bytes):
|
||||
self.bytes += bytes
|
||||
if self.bytes == b"bytes":
|
||||
self.received.callback(None)
|
||||
|
||||
class GreeterProtocol(Protocol):
|
||||
def connectionMade(self):
|
||||
if write == "write":
|
||||
self.transport.write(b"bytes")
|
||||
else:
|
||||
self.transport.writeSequence([b"byt", b"es"])
|
||||
|
||||
if testServer:
|
||||
server = GreeterProtocol()
|
||||
client = GreeteeProtocol()
|
||||
d = client.received = Deferred()
|
||||
else:
|
||||
server = GreeteeProtocol()
|
||||
d = server.received = Deferred()
|
||||
client = GreeterProtocol()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
return d
|
||||
|
||||
|
||||
def test_clientGreeting(self):
|
||||
"""
|
||||
Test that on a connection where the client speaks first, the server
|
||||
receives the bytes sent by the client.
|
||||
"""
|
||||
return self._greetingtest("write", False)
|
||||
|
||||
|
||||
def test_clientGreetingSequence(self):
|
||||
"""
|
||||
Like C{test_clientGreeting}, but use C{writeSequence} instead of
|
||||
C{write} to issue the greeting.
|
||||
"""
|
||||
return self._greetingtest("writeSequence", False)
|
||||
|
||||
|
||||
def test_serverGreeting(self, write="write"):
|
||||
"""
|
||||
Test that on a connection where the server speaks first, the client
|
||||
receives the bytes sent by the server.
|
||||
"""
|
||||
return self._greetingtest("write", True)
|
||||
|
||||
|
||||
def test_serverGreetingSequence(self):
|
||||
"""
|
||||
Like C{test_serverGreeting}, but use C{writeSequence} instead of
|
||||
C{write} to issue the greeting.
|
||||
"""
|
||||
return self._greetingtest("writeSequence", True)
|
||||
|
||||
|
||||
def _producertest(self, producerClass):
|
||||
toProduce = list(map(intToBytes, range(0, 10)))
|
||||
|
||||
class ProducingProtocol(Protocol):
|
||||
def connectionMade(self):
|
||||
self.producer = producerClass(list(toProduce))
|
||||
self.producer.start(self.transport)
|
||||
|
||||
class ReceivingProtocol(Protocol):
|
||||
bytes = b""
|
||||
def dataReceived(self, data):
|
||||
self.bytes += data
|
||||
if self.bytes == b''.join(toProduce):
|
||||
self.received.callback((client, server))
|
||||
|
||||
server = ProducingProtocol()
|
||||
client = ReceivingProtocol()
|
||||
client.received = Deferred()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
return client.received
|
||||
|
||||
|
||||
def test_pushProducer(self):
|
||||
"""
|
||||
Test a push producer registered against a loopback transport.
|
||||
"""
|
||||
@implementer(IPushProducer)
|
||||
class PushProducer(object):
|
||||
resumed = False
|
||||
|
||||
def __init__(self, toProduce):
|
||||
self.toProduce = toProduce
|
||||
|
||||
def resumeProducing(self):
|
||||
self.resumed = True
|
||||
|
||||
def start(self, consumer):
|
||||
self.consumer = consumer
|
||||
consumer.registerProducer(self, True)
|
||||
self._produceAndSchedule()
|
||||
|
||||
def _produceAndSchedule(self):
|
||||
if self.toProduce:
|
||||
self.consumer.write(self.toProduce.pop(0))
|
||||
reactor.callLater(0, self._produceAndSchedule)
|
||||
else:
|
||||
self.consumer.unregisterProducer()
|
||||
d = self._producertest(PushProducer)
|
||||
|
||||
def finished(results):
|
||||
(client, server) = results
|
||||
self.assertFalse(
|
||||
server.producer.resumed,
|
||||
"Streaming producer should not have been resumed.")
|
||||
d.addCallback(finished)
|
||||
return d
|
||||
|
||||
|
||||
def test_pullProducer(self):
|
||||
"""
|
||||
Test a pull producer registered against a loopback transport.
|
||||
"""
|
||||
@implementer(IPullProducer)
|
||||
class PullProducer(object):
|
||||
def __init__(self, toProduce):
|
||||
self.toProduce = toProduce
|
||||
|
||||
def start(self, consumer):
|
||||
self.consumer = consumer
|
||||
self.consumer.registerProducer(self, False)
|
||||
|
||||
def resumeProducing(self):
|
||||
self.consumer.write(self.toProduce.pop(0))
|
||||
if not self.toProduce:
|
||||
self.consumer.unregisterProducer()
|
||||
return self._producertest(PullProducer)
|
||||
|
||||
|
||||
def test_writeNotReentrant(self):
|
||||
"""
|
||||
L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
|
||||
method while that protocol's transport's C{write} method is higher up
|
||||
on the stack.
|
||||
"""
|
||||
class Server(Protocol):
|
||||
def dataReceived(self, bytes):
|
||||
self.transport.write(b"bytes")
|
||||
|
||||
class Client(Protocol):
|
||||
ready = False
|
||||
|
||||
def connectionMade(self):
|
||||
reactor.callLater(0, self.go)
|
||||
|
||||
def go(self):
|
||||
self.transport.write(b"foo")
|
||||
self.ready = True
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
self.wasReady = self.ready
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
server = Server()
|
||||
client = Client()
|
||||
d = loopback.loopbackAsync(client, server)
|
||||
def cbFinished(ignored):
|
||||
self.assertTrue(client.wasReady)
|
||||
d.addCallback(cbFinished)
|
||||
return d
|
||||
|
||||
|
||||
def test_pumpPolicy(self):
|
||||
"""
|
||||
The callable passed as the value for the C{pumpPolicy} parameter to
|
||||
L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
|
||||
and a protocol to which they should be delivered.
|
||||
"""
|
||||
pumpCalls = []
|
||||
def dummyPolicy(queue, target):
|
||||
bytes = []
|
||||
while queue:
|
||||
bytes.append(queue.get())
|
||||
pumpCalls.append((target, bytes))
|
||||
|
||||
client = Protocol()
|
||||
server = Protocol()
|
||||
|
||||
finished = loopback.loopbackAsync(server, client, dummyPolicy)
|
||||
self.assertEqual(pumpCalls, [])
|
||||
|
||||
client.transport.write(b"foo")
|
||||
client.transport.write(b"bar")
|
||||
server.transport.write(b"baz")
|
||||
server.transport.write(b"quux")
|
||||
server.transport.loseConnection()
|
||||
|
||||
def cbComplete(ignored):
|
||||
self.assertEqual(
|
||||
pumpCalls,
|
||||
# The order here is somewhat arbitrary. The implementation
|
||||
# happens to always deliver data to the client first.
|
||||
[(client, [b"baz", b"quux", None]),
|
||||
(server, [b"foo", b"bar"])])
|
||||
finished.addCallback(cbComplete)
|
||||
return finished
|
||||
|
||||
|
||||
def test_identityPumpPolicy(self):
|
||||
"""
|
||||
L{identityPumpPolicy} is a pump policy which calls the target's
|
||||
C{dataReceived} method one for each string in the queue passed to it.
|
||||
"""
|
||||
bytes = []
|
||||
client = Protocol()
|
||||
client.dataReceived = bytes.append
|
||||
queue = loopback._LoopbackQueue()
|
||||
queue.put(b"foo")
|
||||
queue.put(b"bar")
|
||||
queue.put(None)
|
||||
|
||||
loopback.identityPumpPolicy(queue, client)
|
||||
|
||||
self.assertEqual(bytes, [b"foo", b"bar"])
|
||||
|
||||
|
||||
def test_collapsingPumpPolicy(self):
|
||||
"""
|
||||
L{collapsingPumpPolicy} is a pump policy which calls the target's
|
||||
C{dataReceived} only once with all of the strings in the queue passed
|
||||
to it joined together.
|
||||
"""
|
||||
bytes = []
|
||||
client = Protocol()
|
||||
client.dataReceived = bytes.append
|
||||
queue = loopback._LoopbackQueue()
|
||||
queue.put(b"foo")
|
||||
queue.put(b"bar")
|
||||
queue.put(None)
|
||||
|
||||
loopback.collapsingPumpPolicy(queue, client)
|
||||
|
||||
self.assertEqual(bytes, [b"foobar"])
|
||||
|
||||
|
||||
|
||||
class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackTCP)
|
||||
|
||||
|
||||
class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackUNIX)
|
||||
|
||||
if interfaces.IReactorUNIX(reactor, None) is None:
|
||||
skip = "Current reactor does not support UNIX sockets"
|
||||
elif _PY3:
|
||||
skip = "UNIX sockets not supported on Python 3. See #6136"
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.manhole import service
|
||||
from twisted.spread.util import LocalAsRemote
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
class DummyTransport:
|
||||
def getHost(self):
|
||||
return 'INET', '127.0.0.1', 0
|
||||
|
||||
class DummyManholeClient(LocalAsRemote):
|
||||
zero = 0
|
||||
broker = Dummy()
|
||||
broker.transport = DummyTransport()
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
def console(self, messages):
|
||||
self.messages.extend(messages)
|
||||
|
||||
def receiveExplorer(self, xplorer):
|
||||
pass
|
||||
|
||||
def setZero(self):
|
||||
self.zero = len(self.messages)
|
||||
|
||||
def getMessages(self):
|
||||
return self.messages[self.zero:]
|
||||
|
||||
# local interface
|
||||
sync_console = console
|
||||
sync_receiveExplorer = receiveExplorer
|
||||
sync_setZero = setZero
|
||||
sync_getMessages = getMessages
|
||||
|
||||
class ManholeTest(unittest.TestCase):
|
||||
"""Various tests for the manhole service.
|
||||
|
||||
Both the the importIdentity and importMain tests are known to fail
|
||||
when the __name__ in the manhole namespace is set to certain
|
||||
values.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.service = service.Service()
|
||||
self.p = service.Perspective(self.service)
|
||||
self.client = DummyManholeClient()
|
||||
self.p.attached(self.client, None)
|
||||
|
||||
def test_importIdentity(self):
|
||||
"""Making sure imported module is the same as one previously loaded.
|
||||
"""
|
||||
self.p.perspective_do("from twisted.manhole import service")
|
||||
self.client.setZero()
|
||||
self.p.perspective_do("int(service is sys.modules['twisted.manhole.service'])")
|
||||
msg = self.client.getMessages()[0]
|
||||
self.assertEqual(msg, ('result',"1\n"))
|
||||
|
||||
def test_importMain(self):
|
||||
"""Trying to import __main__"""
|
||||
self.client.setZero()
|
||||
self.p.perspective_do("import __main__")
|
||||
if self.client.getMessages():
|
||||
msg = self.client.getMessages()[0]
|
||||
if msg[0] in ("exception","stderr"):
|
||||
self.fail(msg[1])
|
||||
|
||||
#if __name__=='__main__':
|
||||
# unittest.main()
|
||||
|
|
@ -0,0 +1,699 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test the memcache client protocol.
|
||||
"""
|
||||
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
||||
from twisted.protocols.memcache import MemCacheProtocol, NoSuchCommand
|
||||
from twisted.protocols.memcache import ClientError, ServerError
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.test.proto_helpers import StringTransportWithDisconnection
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.defer import Deferred, gatherResults, TimeoutError
|
||||
from twisted.internet.defer import DeferredList
|
||||
|
||||
|
||||
|
||||
class CommandMixin:
|
||||
"""
|
||||
Setup and tests for basic invocation of L{MemCacheProtocol} commands.
|
||||
"""
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Helper test method to test the resulting C{Deferred} of a
|
||||
L{MemCacheProtocol} command.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
L{MemCacheProtocol.get} returns a L{Deferred} which is called back with
|
||||
the value and the flag associated with the given key if the server
|
||||
returns a successful result.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get("foo"), "get foo\r\n",
|
||||
"VALUE foo 0 3\r\nbar\r\nEND\r\n", (0, "bar"))
|
||||
|
||||
|
||||
def test_emptyGet(self):
|
||||
"""
|
||||
Test getting a non-available key: it succeeds but return C{None} as
|
||||
value and C{0} as flag.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get("foo"), "get foo\r\n",
|
||||
"END\r\n", (0, None))
|
||||
|
||||
|
||||
def test_getMultiple(self):
|
||||
"""
|
||||
L{MemCacheProtocol.getMultiple} returns a L{Deferred} which is called
|
||||
back with a dictionary of flag, value for each given key.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple(['foo', 'cow']),
|
||||
"get foo cow\r\n",
|
||||
"VALUE foo 0 3\r\nbar\r\nVALUE cow 0 7\r\nchicken\r\nEND\r\n",
|
||||
{'cow': (0, 'chicken'), 'foo': (0, 'bar')})
|
||||
|
||||
|
||||
def test_getMultipleWithEmpty(self):
|
||||
"""
|
||||
When L{MemCacheProtocol.getMultiple} is called with non-available keys,
|
||||
the corresponding tuples are (0, None).
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple(['foo', 'cow']),
|
||||
"get foo cow\r\n",
|
||||
"VALUE cow 1 3\r\nbar\r\nEND\r\n",
|
||||
{'cow': (1, 'bar'), 'foo': (0, None)})
|
||||
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
L{MemCacheProtocol.set} returns a L{Deferred} which is called back with
|
||||
C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.set("foo", "bar"),
|
||||
"set foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_add(self):
|
||||
"""
|
||||
L{MemCacheProtocol.add} returns a L{Deferred} which is called back with
|
||||
C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.add("foo", "bar"),
|
||||
"add foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_replace(self):
|
||||
"""
|
||||
L{MemCacheProtocol.replace} returns a L{Deferred} which is called back
|
||||
with C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.replace("foo", "bar"),
|
||||
"replace foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_errorAdd(self):
|
||||
"""
|
||||
Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
|
||||
key already exists on the server, it returns a B{NOT STORED} answer,
|
||||
which calls back the resulting L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.add("foo", "bar"),
|
||||
"add foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)
|
||||
|
||||
|
||||
def test_errorReplace(self):
|
||||
"""
|
||||
Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
|
||||
but the key doesn't exist on the server, it returns a B{NOT STORED}
|
||||
answer, which calls back the resulting L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.replace("foo", "bar"),
|
||||
"replace foo 0 0 3\r\nbar\r\n", "NOT STORED\r\n", False)
|
||||
|
||||
|
||||
def test_delete(self):
|
||||
"""
|
||||
L{MemCacheProtocol.delete} returns a L{Deferred} which is called back
|
||||
with C{True} when the server notifies a success.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.delete("bar"), "delete bar\r\n", "DELETED\r\n", True)
|
||||
|
||||
|
||||
def test_errorDelete(self):
|
||||
"""
|
||||
Test a error during a delete: if key doesn't exist on the server, it
|
||||
returns a B{NOT FOUND} answer which calls back the resulting
|
||||
L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.delete("bar"), "delete bar\r\n", "NOT FOUND\r\n", False)
|
||||
|
||||
|
||||
def test_increment(self):
|
||||
"""
|
||||
Test incrementing a variable: L{MemCacheProtocol.increment} returns a
|
||||
L{Deferred} which is called back with the incremented value of the
|
||||
given key.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.increment("foo"), "incr foo 1\r\n", "4\r\n", 4)
|
||||
|
||||
|
||||
def test_decrement(self):
|
||||
"""
|
||||
Test decrementing a variable: L{MemCacheProtocol.decrement} returns a
|
||||
L{Deferred} which is called back with the decremented value of the
|
||||
given key.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.decrement("foo"), "decr foo 1\r\n", "5\r\n", 5)
|
||||
|
||||
|
||||
def test_incrementVal(self):
|
||||
"""
|
||||
L{MemCacheProtocol.increment} takes an optional argument C{value} which
|
||||
replaces the default value of 1 when specified.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.increment("foo", 8), "incr foo 8\r\n", "4\r\n", 4)
|
||||
|
||||
|
||||
def test_decrementVal(self):
|
||||
"""
|
||||
L{MemCacheProtocol.decrement} takes an optional argument C{value} which
|
||||
replaces the default value of 1 when specified.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.decrement("foo", 3), "decr foo 3\r\n", "5\r\n", 5)
|
||||
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
Test retrieving server statistics via the L{MemCacheProtocol.stats}
|
||||
command: it parses the data sent by the server and calls back the
|
||||
resulting L{Deferred} with a dictionary of the received statistics.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.stats(), "stats\r\n",
|
||||
"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
|
||||
{"foo": "bar", "egg": "spam"})
|
||||
|
||||
|
||||
def test_statsWithArgument(self):
|
||||
"""
|
||||
L{MemCacheProtocol.stats} takes an optional C{str} argument which,
|
||||
if specified, is sent along with the I{STAT} command. The I{STAT}
|
||||
responses from the server are parsed as key/value pairs and returned
|
||||
as a C{dict} (as in the case where the argument is not specified).
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.stats("blah"), "stats blah\r\n",
|
||||
"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
|
||||
{"foo": "bar", "egg": "spam"})
|
||||
|
||||
|
||||
def test_version(self):
|
||||
"""
|
||||
Test version retrieval via the L{MemCacheProtocol.version} command: it
|
||||
returns a L{Deferred} which is called back with the version sent by the
|
||||
server.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.version(), "version\r\n", "VERSION 1.1\r\n", "1.1")
|
||||
|
||||
|
||||
def test_flushAll(self):
|
||||
"""
|
||||
L{MemCacheProtocol.flushAll} returns a L{Deferred} which is called back
|
||||
with C{True} if the server acknowledges success.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.flushAll(), "flush_all\r\n", "OK\r\n", True)
|
||||
|
||||
|
||||
|
||||
class MemCacheTestCase(CommandMixin, TestCase):
|
||||
"""
|
||||
Test client protocol class L{MemCacheProtocol}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a memcache client, connect it to a string protocol, and make it
|
||||
use a deterministic clock.
|
||||
"""
|
||||
self.proto = MemCacheProtocol()
|
||||
self.clock = Clock()
|
||||
self.proto.callLater = self.clock.callLater
|
||||
self.transport = StringTransportWithDisconnection()
|
||||
self.transport.protocol = self.proto
|
||||
self.proto.makeConnection(self.transport)
|
||||
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Implementation of C{_test} which checks that the command sends C{send}
|
||||
data, and that upon reception of C{recv} the result is C{result}.
|
||||
|
||||
@param d: the resulting deferred from the memcache command.
|
||||
@type d: C{Deferred}
|
||||
|
||||
@param send: the expected data to be sent.
|
||||
@type send: C{str}
|
||||
|
||||
@param recv: the data to simulate as reception.
|
||||
@type recv: C{str}
|
||||
|
||||
@param result: the expected result.
|
||||
@type result: C{any}
|
||||
"""
|
||||
def cb(res):
|
||||
self.assertEqual(res, result)
|
||||
self.assertEqual(self.transport.value(), send)
|
||||
d.addCallback(cb)
|
||||
self.proto.dataReceived(recv)
|
||||
return d
|
||||
|
||||
|
||||
def test_invalidGetResponse(self):
|
||||
"""
|
||||
If the value returned doesn't match the expected key of the current
|
||||
C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
|
||||
"""
|
||||
self.proto.get("foo")
|
||||
s = "spamegg"
|
||||
self.assertRaises(
|
||||
RuntimeError, self.proto.dataReceived,
|
||||
"VALUE bar 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))
|
||||
|
||||
|
||||
def test_invalidMultipleGetResponse(self):
|
||||
"""
|
||||
If the value returned doesn't match one the expected keys of the
|
||||
current multiple C{get} command, an error is raised error in
|
||||
L{MemCacheProtocol.dataReceived}.
|
||||
"""
|
||||
self.proto.getMultiple(["foo", "bar"])
|
||||
s = "spamegg"
|
||||
self.assertRaises(
|
||||
RuntimeError, self.proto.dataReceived,
|
||||
"VALUE egg 0 %s\r\n%s\r\nEND\r\n" % (len(s), s))
|
||||
|
||||
|
||||
def test_timeOut(self):
|
||||
"""
|
||||
Test the timeout on outgoing requests: when timeout is detected, all
|
||||
current commands fail with a L{TimeoutError}, and the connection is
|
||||
closed.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d2 = self.proto.get("bar")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, TimeoutError)
|
||||
|
||||
def checkMessage(error):
|
||||
self.assertEqual(str(error), "Connection timeout")
|
||||
|
||||
d1.addCallback(checkMessage)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
|
||||
def test_timeoutRemoved(self):
|
||||
"""
|
||||
When a request gets a response, no pending timeout call remains around.
|
||||
"""
|
||||
d = self.proto.get("foo")
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
|
||||
|
||||
def check(result):
|
||||
self.assertEqual(result, (0, "bar"))
|
||||
self.assertEqual(len(self.clock.calls), 0)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
def test_timeOutRaw(self):
|
||||
"""
|
||||
Test the timeout when raw mode was started: the timeout is not reset
|
||||
until all the data has been received, so we can have a L{TimeoutError}
|
||||
when waiting for raw data.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d2 = Deferred()
|
||||
self.proto.connectionLost = d2.callback
|
||||
|
||||
self.proto.dataReceived("VALUE foo 0 10\r\n12345")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, ConnectionDone)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
|
||||
def test_timeOutStat(self):
|
||||
"""
|
||||
Test the timeout when stat command has started: the timeout is not
|
||||
reset until the final B{END} is received.
|
||||
"""
|
||||
d1 = self.proto.stats()
|
||||
d2 = Deferred()
|
||||
self.proto.connectionLost = d2.callback
|
||||
|
||||
self.proto.dataReceived("STAT foo bar\r\n")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, ConnectionDone)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
|
||||
def test_timeoutPipelining(self):
|
||||
"""
|
||||
When two requests are sent, a timeout call remains around for the
|
||||
second request, and its timeout time is correct.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d2 = self.proto.get("bar")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n")
|
||||
|
||||
def check(result):
|
||||
self.assertEqual(result, (0, "bar"))
|
||||
self.assertEqual(len(self.clock.calls), 1)
|
||||
for i in range(self.proto.persistentTimeOut):
|
||||
self.clock.advance(1)
|
||||
return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
|
||||
|
||||
def checkTime(ignored):
|
||||
# Check that the timeout happened C{self.proto.persistentTimeOut}
|
||||
# after the last response
|
||||
self.assertEqual(
|
||||
self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1)
|
||||
|
||||
d1.addCallback(check)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return d1
|
||||
|
||||
|
||||
def test_timeoutNotReset(self):
|
||||
"""
|
||||
Check that timeout is not resetted for every command, but keep the
|
||||
timeout from the first command without response.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
d2 = self.proto.get("bar")
|
||||
self.clock.advance(1)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, TimeoutError)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
|
||||
def test_timeoutCleanDeferreds(self):
|
||||
"""
|
||||
C{timeoutConnection} cleans the list of commands that it fires with
|
||||
C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
|
||||
sets the disconnected state so that future commands fail with a
|
||||
C{RuntimeError}.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
d2 = self.proto.get("bar")
|
||||
self.assertFailure(d2, RuntimeError)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
|
||||
def test_connectionLost(self):
|
||||
"""
|
||||
When disconnection occurs while commands are still outstanding, the
|
||||
commands fail.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d2 = self.proto.get("bar")
|
||||
self.transport.loseConnection()
|
||||
done = DeferredList([d1, d2], consumeErrors=True)
|
||||
|
||||
def checkFailures(results):
|
||||
for success, result in results:
|
||||
self.assertFalse(success)
|
||||
result.trap(ConnectionDone)
|
||||
|
||||
return done.addCallback(checkFailures)
|
||||
|
||||
|
||||
def test_tooLongKey(self):
|
||||
"""
|
||||
An error is raised when trying to use a too long key: the called
|
||||
command returns a L{Deferred} which fails with a L{ClientError}.
|
||||
"""
|
||||
d1 = self.assertFailure(self.proto.set("a" * 500, "bar"), ClientError)
|
||||
d2 = self.assertFailure(self.proto.increment("a" * 500), ClientError)
|
||||
d3 = self.assertFailure(self.proto.get("a" * 500), ClientError)
|
||||
d4 = self.assertFailure(
|
||||
self.proto.append("a" * 500, "bar"), ClientError)
|
||||
d5 = self.assertFailure(
|
||||
self.proto.prepend("a" * 500, "bar"), ClientError)
|
||||
d6 = self.assertFailure(
|
||||
self.proto.getMultiple(["foo", "a" * 500]), ClientError)
|
||||
return gatherResults([d1, d2, d3, d4, d5, d6])
|
||||
|
||||
|
||||
def test_invalidCommand(self):
|
||||
"""
|
||||
When an unknown command is sent directly (not through public API), the
|
||||
server answers with an B{ERROR} token, and the command fails with
|
||||
L{NoSuchCommand}.
|
||||
"""
|
||||
d = self.proto._set("egg", "foo", "bar", 0, 0, "")
|
||||
self.assertEqual(self.transport.value(), "egg foo 0 0 3\r\nbar\r\n")
|
||||
self.assertFailure(d, NoSuchCommand)
|
||||
self.proto.dataReceived("ERROR\r\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_clientError(self):
|
||||
"""
|
||||
Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
|
||||
token, the originating command fails with L{ClientError}, and the error
|
||||
contains the text sent by the server.
|
||||
"""
|
||||
a = "eggspamm"
|
||||
d = self.proto.set("foo", a)
|
||||
self.assertEqual(self.transport.value(),
|
||||
"set foo 0 0 8\r\neggspamm\r\n")
|
||||
self.assertFailure(d, ClientError)
|
||||
|
||||
def check(err):
|
||||
self.assertEqual(str(err), "We don't like egg and spam")
|
||||
|
||||
d.addCallback(check)
|
||||
self.proto.dataReceived("CLIENT_ERROR We don't like egg and spam\r\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_serverError(self):
|
||||
"""
|
||||
Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
|
||||
token, the originating command fails with L{ServerError}, and the error
|
||||
contains the text sent by the server.
|
||||
"""
|
||||
a = "eggspamm"
|
||||
d = self.proto.set("foo", a)
|
||||
self.assertEqual(self.transport.value(),
|
||||
"set foo 0 0 8\r\neggspamm\r\n")
|
||||
self.assertFailure(d, ServerError)
|
||||
|
||||
def check(err):
|
||||
self.assertEqual(str(err), "zomg")
|
||||
|
||||
d.addCallback(check)
|
||||
self.proto.dataReceived("SERVER_ERROR zomg\r\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_unicodeKey(self):
|
||||
"""
|
||||
Using a non-string key as argument to commands raises an error.
|
||||
"""
|
||||
d1 = self.assertFailure(self.proto.set(u"foo", "bar"), ClientError)
|
||||
d2 = self.assertFailure(self.proto.increment(u"egg"), ClientError)
|
||||
d3 = self.assertFailure(self.proto.get(1), ClientError)
|
||||
d4 = self.assertFailure(self.proto.delete(u"bar"), ClientError)
|
||||
d5 = self.assertFailure(self.proto.append(u"foo", "bar"), ClientError)
|
||||
d6 = self.assertFailure(self.proto.prepend(u"foo", "bar"), ClientError)
|
||||
d7 = self.assertFailure(
|
||||
self.proto.getMultiple(["egg", 1]), ClientError)
|
||||
return gatherResults([d1, d2, d3, d4, d5, d6, d7])
|
||||
|
||||
|
||||
def test_unicodeValue(self):
|
||||
"""
|
||||
Using a non-string value raises an error.
|
||||
"""
|
||||
return self.assertFailure(self.proto.set("foo", u"bar"), ClientError)
|
||||
|
||||
|
||||
def test_pipelining(self):
|
||||
"""
|
||||
Multiple requests can be sent subsequently to the server, and the
|
||||
protocol orders the responses correctly and dispatch to the
|
||||
corresponding client command.
|
||||
"""
|
||||
d1 = self.proto.get("foo")
|
||||
d1.addCallback(self.assertEqual, (0, "bar"))
|
||||
d2 = self.proto.set("bar", "spamspamspam")
|
||||
d2.addCallback(self.assertEqual, True)
|
||||
d3 = self.proto.get("egg")
|
||||
d3.addCallback(self.assertEqual, (0, "spam"))
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n")
|
||||
self.proto.dataReceived("VALUE foo 0 3\r\nbar\r\nEND\r\n"
|
||||
"STORED\r\n"
|
||||
"VALUE egg 0 4\r\nspam\r\nEND\r\n")
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
|
||||
def test_getInChunks(self):
|
||||
"""
|
||||
If the value retrieved by a C{get} arrive in chunks, the protocol
|
||||
is able to reconstruct it and to produce the good value.
|
||||
"""
|
||||
d = self.proto.get("foo")
|
||||
d.addCallback(self.assertEqual, (0, "0123456789"))
|
||||
self.assertEqual(self.transport.value(), "get foo\r\n")
|
||||
self.proto.dataReceived("VALUE foo 0 10\r\n0123456")
|
||||
self.proto.dataReceived("789")
|
||||
self.proto.dataReceived("\r\nEND")
|
||||
self.proto.dataReceived("\r\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_append(self):
|
||||
"""
|
||||
L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
|
||||
method: it returns a L{Deferred} which is called back with C{True} when
|
||||
the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.append("foo", "bar"),
|
||||
"append foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_prepend(self):
|
||||
"""
|
||||
L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
|
||||
method: it returns a L{Deferred} which is called back with C{True} when
|
||||
the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.prepend("foo", "bar"),
|
||||
"prepend foo 0 0 3\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_gets(self):
|
||||
"""
|
||||
L{MemCacheProtocol.get} handles an additional cas result when
|
||||
C{withIdentifier} is C{True} and forward it in the resulting
|
||||
L{Deferred}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get("foo", True), "gets foo\r\n",
|
||||
"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n", (0, "1234", "bar"))
|
||||
|
||||
|
||||
def test_emptyGets(self):
|
||||
"""
|
||||
Test getting a non-available key with gets: it succeeds but return
|
||||
C{None} as value, C{0} as flag and an empty cas value.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get("foo", True), "gets foo\r\n",
|
||||
"END\r\n", (0, "", None))
|
||||
|
||||
|
||||
def test_getsMultiple(self):
|
||||
"""
|
||||
L{MemCacheProtocol.getMultiple} handles an additional cas field in the
|
||||
returned tuples if C{withIdentifier} is C{True}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple(["foo", "bar"], True),
|
||||
"gets foo bar\r\n",
|
||||
"VALUE foo 0 3 1234\r\negg\r\n"
|
||||
"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
|
||||
{'bar': (0, '2345', 'spam'), 'foo': (0, '1234', 'egg')})
|
||||
|
||||
|
||||
def test_getsMultipleWithEmpty(self):
|
||||
"""
|
||||
When getting a non-available key with L{MemCacheProtocol.getMultiple}
|
||||
when C{withIdentifier} is C{True}, the other keys are retrieved
|
||||
correctly, and the non-available key gets a tuple of C{0} as flag,
|
||||
C{None} as value, and an empty cas value.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple(["foo", "bar"], True),
|
||||
"gets foo bar\r\n",
|
||||
"VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
|
||||
{'bar': (0, '', None), 'foo': (0, '1234', 'egg')})
|
||||
|
||||
|
||||
def test_checkAndSet(self):
|
||||
"""
|
||||
L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
|
||||
that the server handles to check if the data has to be updated.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.checkAndSet("foo", "bar", cas="1234"),
|
||||
"cas foo 0 0 3 1234\r\nbar\r\n", "STORED\r\n", True)
|
||||
|
||||
|
||||
def test_casUnknowKey(self):
|
||||
"""
|
||||
When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
|
||||
resulting L{Deferred} fires with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.checkAndSet("foo", "bar", cas="1234"),
|
||||
"cas foo 0 0 3 1234\r\nbar\r\n", "EXISTS\r\n", False)
|
||||
|
||||
|
||||
|
||||
class CommandFailureTests(CommandMixin, TestCase):
|
||||
"""
|
||||
Tests for correct failure of commands on a disconnected
|
||||
L{MemCacheProtocol}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a disconnected memcache client, using a deterministic clock.
|
||||
"""
|
||||
self.proto = MemCacheProtocol()
|
||||
self.clock = Clock()
|
||||
self.proto.callLater = self.clock.callLater
|
||||
self.transport = StringTransportWithDisconnection()
|
||||
self.transport.protocol = self.proto
|
||||
self.proto.makeConnection(self.transport)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Implementation of C{_test} which checks that the command fails with
|
||||
C{RuntimeError} because the transport is disconnected. All the
|
||||
parameters except C{d} are ignored.
|
||||
"""
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
|
@ -0,0 +1,514 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for twisted.python.modules, abstract access to imported or importable
|
||||
objects.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import itertools
|
||||
import zipfile
|
||||
import compileall
|
||||
|
||||
import twisted
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
from twisted.python import modules
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
from twisted.python.test.modules_helpers import TwistedModulesMixin
|
||||
from twisted.python.test.test_zippath import zipit
|
||||
|
||||
|
||||
class TwistedModulesTestCase(TwistedModulesMixin, TestCase):
|
||||
"""
|
||||
Base class for L{modules} test cases.
|
||||
"""
|
||||
def findByIteration(self, modname, where=modules, importPackages=False):
|
||||
"""
|
||||
You don't ever actually want to do this, so it's not in the public
|
||||
API, but sometimes we want to compare the result of an iterative call
|
||||
with a lookup call and make sure they're the same for test purposes.
|
||||
"""
|
||||
for modinfo in where.walkModules(importPackages=importPackages):
|
||||
if modinfo.name == modname:
|
||||
return modinfo
|
||||
self.fail("Unable to find module %r through iteration." % (modname,))
|
||||
|
||||
|
||||
|
||||
class BasicTests(TwistedModulesTestCase):
|
||||
|
||||
def test_namespacedPackages(self):
|
||||
"""
|
||||
Duplicate packages are not yielded when iterating over namespace
|
||||
packages.
|
||||
"""
|
||||
# Force pkgutil to be loaded already, since the probe package being
|
||||
# created depends on it, and the replaceSysPath call below will make
|
||||
# pretty much everything unimportable.
|
||||
__import__('pkgutil')
|
||||
|
||||
namespaceBoilerplate = (
|
||||
'import pkgutil; '
|
||||
'__path__ = pkgutil.extend_path(__path__, __name__)')
|
||||
|
||||
# Create two temporary directories with packages:
|
||||
#
|
||||
# entry:
|
||||
# test_package/
|
||||
# __init__.py
|
||||
# nested_package/
|
||||
# __init__.py
|
||||
# module.py
|
||||
#
|
||||
# anotherEntry:
|
||||
# test_package/
|
||||
# __init__.py
|
||||
# nested_package/
|
||||
# __init__.py
|
||||
# module2.py
|
||||
#
|
||||
# test_package and test_package.nested_package are namespace packages,
|
||||
# and when both of these are in sys.path, test_package.nested_package
|
||||
# should become a virtual package containing both "module" and
|
||||
# "module2"
|
||||
|
||||
entry = self.pathEntryWithOnePackage()
|
||||
testPackagePath = entry.child('test_package')
|
||||
testPackagePath.child('__init__.py').setContent(namespaceBoilerplate)
|
||||
|
||||
nestedEntry = testPackagePath.child('nested_package')
|
||||
nestedEntry.makedirs()
|
||||
nestedEntry.child('__init__.py').setContent(namespaceBoilerplate)
|
||||
nestedEntry.child('module.py').setContent('')
|
||||
|
||||
anotherEntry = self.pathEntryWithOnePackage()
|
||||
anotherPackagePath = anotherEntry.child('test_package')
|
||||
anotherPackagePath.child('__init__.py').setContent(namespaceBoilerplate)
|
||||
|
||||
anotherNestedEntry = anotherPackagePath.child('nested_package')
|
||||
anotherNestedEntry.makedirs()
|
||||
anotherNestedEntry.child('__init__.py').setContent(namespaceBoilerplate)
|
||||
anotherNestedEntry.child('module2.py').setContent('')
|
||||
|
||||
self.replaceSysPath([entry.path, anotherEntry.path])
|
||||
|
||||
module = modules.getModule('test_package')
|
||||
|
||||
# We have to use importPackages=True in order to resolve the namespace
|
||||
# packages, so we remove the imported packages from sys.modules after
|
||||
# walking
|
||||
try:
|
||||
walkedNames = [
|
||||
mod.name for mod in module.walkModules(importPackages=True)]
|
||||
finally:
|
||||
for module in sys.modules.keys():
|
||||
if module.startswith('test_package'):
|
||||
del sys.modules[module]
|
||||
|
||||
expected = [
|
||||
'test_package',
|
||||
'test_package.nested_package',
|
||||
'test_package.nested_package.module',
|
||||
'test_package.nested_package.module2',
|
||||
]
|
||||
|
||||
self.assertEqual(walkedNames, expected)
|
||||
|
||||
|
||||
def test_unimportablePackageGetItem(self):
|
||||
"""
|
||||
If a package has been explicitly forbidden from importing by setting a
|
||||
C{None} key in sys.modules under its name,
|
||||
L{modules.PythonPath.__getitem__} should still be able to retrieve an
|
||||
unloaded L{modules.PythonModule} for that package.
|
||||
"""
|
||||
shouldNotLoad = []
|
||||
path = modules.PythonPath(sysPath=[self.pathEntryWithOnePackage().path],
|
||||
moduleLoader=shouldNotLoad.append,
|
||||
importerCache={},
|
||||
sysPathHooks={},
|
||||
moduleDict={'test_package': None})
|
||||
self.assertEqual(shouldNotLoad, [])
|
||||
self.assertEqual(path['test_package'].isLoaded(), False)
|
||||
|
||||
|
||||
def test_unimportablePackageWalkModules(self):
|
||||
"""
|
||||
If a package has been explicitly forbidden from importing by setting a
|
||||
C{None} key in sys.modules under its name, L{modules.walkModules} should
|
||||
still be able to retrieve an unloaded L{modules.PythonModule} for that
|
||||
package.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
self.replaceSysPath([existentPath.path])
|
||||
self.replaceSysModules({"test_package": None})
|
||||
|
||||
walked = list(modules.walkModules())
|
||||
self.assertEqual([m.name for m in walked],
|
||||
["test_package"])
|
||||
self.assertEqual(walked[0].isLoaded(), False)
|
||||
|
||||
|
||||
def test_nonexistentPaths(self):
|
||||
"""
|
||||
Verify that L{modules.walkModules} ignores entries in sys.path which
|
||||
do not exist in the filesystem.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
|
||||
nonexistentPath = FilePath(self.mktemp())
|
||||
self.failIf(nonexistentPath.exists())
|
||||
|
||||
self.replaceSysPath([existentPath.path])
|
||||
|
||||
expected = [modules.getModule("test_package")]
|
||||
|
||||
beforeModules = list(modules.walkModules())
|
||||
sys.path.append(nonexistentPath.path)
|
||||
afterModules = list(modules.walkModules())
|
||||
|
||||
self.assertEqual(beforeModules, expected)
|
||||
self.assertEqual(afterModules, expected)
|
||||
|
||||
|
||||
def test_nonDirectoryPaths(self):
|
||||
"""
|
||||
Verify that L{modules.walkModules} ignores entries in sys.path which
|
||||
refer to regular files in the filesystem.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
|
||||
nonDirectoryPath = FilePath(self.mktemp())
|
||||
self.failIf(nonDirectoryPath.exists())
|
||||
nonDirectoryPath.setContent("zip file or whatever\n")
|
||||
|
||||
self.replaceSysPath([existentPath.path])
|
||||
|
||||
beforeModules = list(modules.walkModules())
|
||||
sys.path.append(nonDirectoryPath.path)
|
||||
afterModules = list(modules.walkModules())
|
||||
|
||||
self.assertEqual(beforeModules, afterModules)
|
||||
|
||||
|
||||
def test_twistedShowsUp(self):
|
||||
"""
|
||||
Scrounge around in the top-level module namespace and make sure that
|
||||
Twisted shows up, and that the module thusly obtained is the same as
|
||||
the module that we find when we look for it explicitly by name.
|
||||
"""
|
||||
self.assertEqual(modules.getModule('twisted'),
|
||||
self.findByIteration("twisted"))
|
||||
|
||||
|
||||
def test_dottedNames(self):
|
||||
"""
|
||||
Verify that the walkModules APIs will give us back subpackages, not just
|
||||
subpackages.
|
||||
"""
|
||||
self.assertEqual(
|
||||
modules.getModule('twisted.python'),
|
||||
self.findByIteration("twisted.python",
|
||||
where=modules.getModule('twisted')))
|
||||
|
||||
|
||||
def test_onlyTopModules(self):
|
||||
"""
|
||||
Verify that the iterModules API will only return top-level modules and
|
||||
packages, not submodules or subpackages.
|
||||
"""
|
||||
for module in modules.iterModules():
|
||||
self.failIf(
|
||||
'.' in module.name,
|
||||
"no nested modules should be returned from iterModules: %r"
|
||||
% (module.filePath))
|
||||
|
||||
|
||||
def test_loadPackagesAndModules(self):
|
||||
"""
|
||||
Verify that we can locate and load packages, modules, submodules, and
|
||||
subpackages.
|
||||
"""
|
||||
for n in ['os',
|
||||
'twisted',
|
||||
'twisted.python',
|
||||
'twisted.python.reflect']:
|
||||
m = namedAny(n)
|
||||
self.failUnlessIdentical(
|
||||
modules.getModule(n).load(),
|
||||
m)
|
||||
self.failUnlessIdentical(
|
||||
self.findByIteration(n).load(),
|
||||
m)
|
||||
|
||||
|
||||
def test_pathEntriesOnPath(self):
|
||||
"""
|
||||
Verify that path entries discovered via module loading are, in fact, on
|
||||
sys.path somewhere.
|
||||
"""
|
||||
for n in ['os',
|
||||
'twisted',
|
||||
'twisted.python',
|
||||
'twisted.python.reflect']:
|
||||
self.failUnlessIn(
|
||||
modules.getModule(n).pathEntry.filePath.path,
|
||||
sys.path)
|
||||
|
||||
|
||||
def test_alwaysPreferPy(self):
|
||||
"""
|
||||
Verify that .py files will always be preferred to .pyc files, regardless of
|
||||
directory listing order.
|
||||
"""
|
||||
mypath = FilePath(self.mktemp())
|
||||
mypath.createDirectory()
|
||||
pp = modules.PythonPath(sysPath=[mypath.path])
|
||||
originalSmartPath = pp._smartPath
|
||||
def _evilSmartPath(pathName):
|
||||
o = originalSmartPath(pathName)
|
||||
originalChildren = o.children
|
||||
def evilChildren():
|
||||
# normally this order is random; let's make sure it always
|
||||
# comes up .pyc-first.
|
||||
x = originalChildren()
|
||||
x.sort()
|
||||
x.reverse()
|
||||
return x
|
||||
o.children = evilChildren
|
||||
return o
|
||||
mypath.child("abcd.py").setContent('\n')
|
||||
compileall.compile_dir(mypath.path, quiet=True)
|
||||
# sanity check
|
||||
self.assertEqual(len(mypath.children()), 2)
|
||||
pp._smartPath = _evilSmartPath
|
||||
self.assertEqual(pp['abcd'].filePath,
|
||||
mypath.child('abcd.py'))
|
||||
|
||||
|
||||
def test_packageMissingPath(self):
|
||||
"""
|
||||
A package can delete its __path__ for some reasons,
|
||||
C{modules.PythonPath} should be able to deal with it.
|
||||
"""
|
||||
mypath = FilePath(self.mktemp())
|
||||
mypath.createDirectory()
|
||||
pp = modules.PythonPath(sysPath=[mypath.path])
|
||||
subpath = mypath.child("abcd")
|
||||
subpath.createDirectory()
|
||||
subpath.child("__init__.py").setContent('del __path__\n')
|
||||
sys.path.append(mypath.path)
|
||||
__import__("abcd")
|
||||
try:
|
||||
l = list(pp.walkModules())
|
||||
self.assertEqual(len(l), 1)
|
||||
self.assertEqual(l[0].name, 'abcd')
|
||||
finally:
|
||||
del sys.modules['abcd']
|
||||
sys.path.remove(mypath.path)
|
||||
|
||||
|
||||
|
||||
class PathModificationTest(TwistedModulesTestCase):
|
||||
"""
|
||||
These tests share setup/cleanup behavior of creating a dummy package and
|
||||
stuffing some code in it.
|
||||
"""
|
||||
|
||||
_serialnum = itertools.count().next # used to generate serial numbers for
|
||||
# package names.
|
||||
|
||||
def setUp(self):
|
||||
self.pathExtensionName = self.mktemp()
|
||||
self.pathExtension = FilePath(self.pathExtensionName)
|
||||
self.pathExtension.createDirectory()
|
||||
self.packageName = "pyspacetests%d" % (self._serialnum(),)
|
||||
self.packagePath = self.pathExtension.child(self.packageName)
|
||||
self.packagePath.createDirectory()
|
||||
self.packagePath.child("__init__.py").setContent("")
|
||||
self.packagePath.child("a.py").setContent("")
|
||||
self.packagePath.child("b.py").setContent("")
|
||||
self.packagePath.child("c__init__.py").setContent("")
|
||||
self.pathSetUp = False
|
||||
|
||||
|
||||
def _setupSysPath(self):
|
||||
assert not self.pathSetUp
|
||||
self.pathSetUp = True
|
||||
sys.path.append(self.pathExtensionName)
|
||||
|
||||
|
||||
def _underUnderPathTest(self, doImport=True):
|
||||
moddir2 = self.mktemp()
|
||||
fpmd = FilePath(moddir2)
|
||||
fpmd.createDirectory()
|
||||
fpmd.child("foozle.py").setContent("x = 123\n")
|
||||
self.packagePath.child("__init__.py").setContent(
|
||||
"__path__.append(%r)\n" % (moddir2,))
|
||||
# Cut here
|
||||
self._setupSysPath()
|
||||
modinfo = modules.getModule(self.packageName)
|
||||
self.assertEqual(
|
||||
self.findByIteration(self.packageName+".foozle", modinfo,
|
||||
importPackages=doImport),
|
||||
modinfo['foozle'])
|
||||
self.assertEqual(modinfo['foozle'].load().x, 123)
|
||||
|
||||
|
||||
def test_underUnderPathAlreadyImported(self):
|
||||
"""
|
||||
Verify that iterModules will honor the __path__ of already-loaded packages.
|
||||
"""
|
||||
self._underUnderPathTest()
|
||||
|
||||
|
||||
def test_underUnderPathNotAlreadyImported(self):
|
||||
"""
|
||||
Verify that iterModules will honor the __path__ of already-loaded packages.
|
||||
"""
|
||||
self._underUnderPathTest(False)
|
||||
|
||||
|
||||
test_underUnderPathNotAlreadyImported.todo = (
|
||||
"This may be impossible but it sure would be nice.")
|
||||
|
||||
|
||||
def _listModules(self):
|
||||
pkginfo = modules.getModule(self.packageName)
|
||||
nfni = [modinfo.name.split(".")[-1] for modinfo in
|
||||
pkginfo.iterModules()]
|
||||
nfni.sort()
|
||||
self.assertEqual(nfni, ['a', 'b', 'c__init__'])
|
||||
|
||||
|
||||
def test_listingModules(self):
|
||||
"""
|
||||
Make sure the module list comes back as we expect from iterModules on a
|
||||
package, whether zipped or not.
|
||||
"""
|
||||
self._setupSysPath()
|
||||
self._listModules()
|
||||
|
||||
|
||||
def test_listingModulesAlreadyImported(self):
|
||||
"""
|
||||
Make sure the module list comes back as we expect from iterModules on a
|
||||
package, whether zipped or not, even if the package has already been
|
||||
imported.
|
||||
"""
|
||||
self._setupSysPath()
|
||||
namedAny(self.packageName)
|
||||
self._listModules()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
# Intentionally using 'assert' here, this is not a test assertion, this
|
||||
# is just an "oh fuck what is going ON" assertion. -glyph
|
||||
if self.pathSetUp:
|
||||
HORK = "path cleanup failed: don't be surprised if other tests break"
|
||||
assert sys.path.pop() is self.pathExtensionName, HORK+", 1"
|
||||
assert self.pathExtensionName not in sys.path, HORK+", 2"
|
||||
|
||||
|
||||
|
||||
class RebindingTest(PathModificationTest):
|
||||
"""
|
||||
These tests verify that the default path interrogation API works properly
|
||||
even when sys.path has been rebound to a different object.
|
||||
"""
|
||||
def _setupSysPath(self):
|
||||
assert not self.pathSetUp
|
||||
self.pathSetUp = True
|
||||
self.savedSysPath = sys.path
|
||||
sys.path = sys.path[:]
|
||||
sys.path.append(self.pathExtensionName)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Clean up sys.path by re-binding our original object.
|
||||
"""
|
||||
if self.pathSetUp:
|
||||
sys.path = self.savedSysPath
|
||||
|
||||
|
||||
|
||||
class ZipPathModificationTest(PathModificationTest):
|
||||
def _setupSysPath(self):
|
||||
assert not self.pathSetUp
|
||||
zipit(self.pathExtensionName, self.pathExtensionName+'.zip')
|
||||
self.pathExtensionName += '.zip'
|
||||
assert zipfile.is_zipfile(self.pathExtensionName)
|
||||
PathModificationTest._setupSysPath(self)
|
||||
|
||||
|
||||
class PythonPathTestCase(TestCase):
|
||||
"""
|
||||
Tests for the class which provides the implementation for all of the
|
||||
public API of L{twisted.python.modules}, L{PythonPath}.
|
||||
"""
|
||||
def test_unhandledImporter(self):
|
||||
"""
|
||||
Make sure that the behavior when encountering an unknown importer
|
||||
type is not catastrophic failure.
|
||||
"""
|
||||
class SecretImporter(object):
|
||||
pass
|
||||
|
||||
def hook(name):
|
||||
return SecretImporter()
|
||||
|
||||
syspath = ['example/path']
|
||||
sysmodules = {}
|
||||
syshooks = [hook]
|
||||
syscache = {}
|
||||
def sysloader(name):
|
||||
return None
|
||||
space = modules.PythonPath(
|
||||
syspath, sysmodules, syshooks, syscache, sysloader)
|
||||
entries = list(space.iterEntries())
|
||||
self.assertEqual(len(entries), 1)
|
||||
self.assertRaises(KeyError, lambda: entries[0]['module'])
|
||||
|
||||
|
||||
def test_inconsistentImporterCache(self):
|
||||
"""
|
||||
If the path a module loaded with L{PythonPath.__getitem__} is not
|
||||
present in the path importer cache, a warning is emitted, but the
|
||||
L{PythonModule} is returned as usual.
|
||||
"""
|
||||
space = modules.PythonPath([], sys.modules, [], {})
|
||||
thisModule = space[__name__]
|
||||
warnings = self.flushWarnings([self.test_inconsistentImporterCache])
|
||||
self.assertEqual(warnings[0]['category'], UserWarning)
|
||||
self.assertEqual(
|
||||
warnings[0]['message'],
|
||||
FilePath(twisted.__file__).parent().dirname() +
|
||||
" (for module " + __name__ + ") not in path importer cache "
|
||||
"(PEP 302 violation - check your local configuration).")
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(thisModule.name, __name__)
|
||||
|
||||
|
||||
def test_containsModule(self):
|
||||
"""
|
||||
L{PythonPath} implements the C{in} operator so that when it is the
|
||||
right-hand argument and the name of a module which exists on that
|
||||
L{PythonPath} is the left-hand argument, the result is C{True}.
|
||||
"""
|
||||
thePath = modules.PythonPath()
|
||||
self.assertIn('os', thePath)
|
||||
|
||||
|
||||
def test_doesntContainModule(self):
|
||||
"""
|
||||
L{PythonPath} implements the C{in} operator so that when it is the
|
||||
right-hand argument and the name of a module which does not exist on
|
||||
that L{PythonPath} is the left-hand argument, the result is C{False}.
|
||||
"""
|
||||
thePath = modules.PythonPath()
|
||||
self.assertNotIn('bogusModule', thePath)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.monkey}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python.monkey import MonkeyPatcher
|
||||
|
||||
|
||||
class TestObj:
|
||||
def __init__(self):
|
||||
self.foo = 'foo value'
|
||||
self.bar = 'bar value'
|
||||
self.baz = 'baz value'
|
||||
|
||||
|
||||
|
||||
class MonkeyPatcherTest(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{MonkeyPatcher} monkey-patching class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.testObject = TestObj()
|
||||
self.originalObject = TestObj()
|
||||
self.monkeyPatcher = MonkeyPatcher()
|
||||
|
||||
|
||||
def test_empty(self):
|
||||
"""
|
||||
A monkey patcher without patches shouldn't change a thing.
|
||||
"""
|
||||
self.monkeyPatcher.patch()
|
||||
|
||||
# We can't assert that all state is unchanged, but at least we can
|
||||
# check our test object.
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
self.assertEqual(self.originalObject.bar, self.testObject.bar)
|
||||
self.assertEqual(self.originalObject.baz, self.testObject.baz)
|
||||
|
||||
|
||||
def test_constructWithPatches(self):
|
||||
"""
|
||||
Constructing a L{MonkeyPatcher} with patches should add all of the
|
||||
given patches to the patch list.
|
||||
"""
|
||||
patcher = MonkeyPatcher((self.testObject, 'foo', 'haha'),
|
||||
(self.testObject, 'bar', 'hehe'))
|
||||
patcher.patch()
|
||||
self.assertEqual('haha', self.testObject.foo)
|
||||
self.assertEqual('hehe', self.testObject.bar)
|
||||
self.assertEqual(self.originalObject.baz, self.testObject.baz)
|
||||
|
||||
|
||||
def test_patchExisting(self):
|
||||
"""
|
||||
Patching an attribute that exists sets it to the value defined in the
|
||||
patch.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
|
||||
self.monkeyPatcher.patch()
|
||||
self.assertEqual(self.testObject.foo, 'haha')
|
||||
|
||||
|
||||
def test_patchNonExisting(self):
|
||||
"""
|
||||
Patching a non-existing attribute fails with an C{AttributeError}.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'nowhere',
|
||||
'blow up please')
|
||||
self.assertRaises(AttributeError, self.monkeyPatcher.patch)
|
||||
|
||||
|
||||
def test_patchAlreadyPatched(self):
|
||||
"""
|
||||
Adding a patch for an object and attribute that already have a patch
|
||||
overrides the existing patch.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'blah')
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'BLAH')
|
||||
self.monkeyPatcher.patch()
|
||||
self.assertEqual(self.testObject.foo, 'BLAH')
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
|
||||
|
||||
def test_restoreTwiceIsANoOp(self):
|
||||
"""
|
||||
Restoring an already-restored monkey patch is a no-op.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'blah')
|
||||
self.monkeyPatcher.patch()
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
|
||||
|
||||
def test_runWithPatchesDecoration(self):
|
||||
"""
|
||||
runWithPatches should run the given callable, passing in all arguments
|
||||
and keyword arguments, and return the return value of the callable.
|
||||
"""
|
||||
log = []
|
||||
|
||||
def f(a, b, c=None):
|
||||
log.append((a, b, c))
|
||||
return 'foo'
|
||||
|
||||
result = self.monkeyPatcher.runWithPatches(f, 1, 2, c=10)
|
||||
self.assertEqual('foo', result)
|
||||
self.assertEqual([(1, 2, 10)], log)
|
||||
|
||||
|
||||
def test_repeatedRunWithPatches(self):
|
||||
"""
|
||||
We should be able to call the same function with runWithPatches more
|
||||
than once. All patches should apply for each call.
|
||||
"""
|
||||
def f():
|
||||
return (self.testObject.foo, self.testObject.bar,
|
||||
self.testObject.baz)
|
||||
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
|
||||
result = self.monkeyPatcher.runWithPatches(f)
|
||||
self.assertEqual(
|
||||
('haha', self.originalObject.bar, self.originalObject.baz), result)
|
||||
result = self.monkeyPatcher.runWithPatches(f)
|
||||
self.assertEqual(
|
||||
('haha', self.originalObject.bar, self.originalObject.baz),
|
||||
result)
|
||||
|
||||
|
||||
def test_runWithPatchesRestores(self):
|
||||
"""
|
||||
C{runWithPatches} should restore the original values after the function
|
||||
has executed.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
self.monkeyPatcher.runWithPatches(lambda: None)
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
|
||||
|
||||
def test_runWithPatchesRestoresOnException(self):
|
||||
"""
|
||||
Test runWithPatches restores the original values even when the function
|
||||
raises an exception.
|
||||
"""
|
||||
def _():
|
||||
self.assertEqual(self.testObject.foo, 'haha')
|
||||
self.assertEqual(self.testObject.bar, 'blahblah')
|
||||
raise RuntimeError("Something went wrong!")
|
||||
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'foo', 'haha')
|
||||
self.monkeyPatcher.addPatch(self.testObject, 'bar', 'blahblah')
|
||||
|
||||
self.assertRaises(RuntimeError, self.monkeyPatcher.runWithPatches, _)
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
self.assertEqual(self.testObject.bar, self.originalObject.bar)
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.cred}, now with 30% more starch.
|
||||
"""
|
||||
|
||||
|
||||
import hmac
|
||||
from zope.interface import implements, Interface
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.cred import portal, checkers, credentials, error
|
||||
from twisted.python import components
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import deferredGenerator as dG, waitForDeferred as wFD
|
||||
|
||||
try:
|
||||
from crypt import crypt
|
||||
except ImportError:
|
||||
crypt = None
|
||||
|
||||
try:
|
||||
from twisted.cred import pamauth
|
||||
except ImportError:
|
||||
pamauth = None
|
||||
|
||||
|
||||
class ITestable(Interface):
|
||||
pass
|
||||
|
||||
class TestAvatar:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.loggedIn = False
|
||||
self.loggedOut = False
|
||||
|
||||
def login(self):
|
||||
assert not self.loggedIn
|
||||
self.loggedIn = True
|
||||
|
||||
def logout(self):
|
||||
self.loggedOut = True
|
||||
|
||||
class Testable(components.Adapter):
|
||||
implements(ITestable)
|
||||
|
||||
# components.Interface(TestAvatar).adaptWith(Testable, ITestable)
|
||||
|
||||
components.registerAdapter(Testable, TestAvatar, ITestable)
|
||||
|
||||
class IDerivedCredentials(credentials.IUsernamePassword):
|
||||
pass
|
||||
|
||||
class DerivedCredentials(object):
|
||||
implements(IDerivedCredentials, ITestable)
|
||||
|
||||
def __init__(self, username, password):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def checkPassword(self, password):
|
||||
return password == self.password
|
||||
|
||||
|
||||
class TestRealm:
|
||||
implements(portal.IRealm)
|
||||
def __init__(self):
|
||||
self.avatars = {}
|
||||
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
if avatarId in self.avatars:
|
||||
avatar = self.avatars[avatarId]
|
||||
else:
|
||||
avatar = TestAvatar(avatarId)
|
||||
self.avatars[avatarId] = avatar
|
||||
avatar.login()
|
||||
return (interfaces[0], interfaces[0](avatar),
|
||||
avatar.logout)
|
||||
|
||||
class NewCredTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
r = self.realm = TestRealm()
|
||||
p = self.portal = portal.Portal(r)
|
||||
up = self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
|
||||
up.addUser("bob", "hello")
|
||||
p.registerChecker(up)
|
||||
|
||||
def testListCheckers(self):
|
||||
expected = [credentials.IUsernamePassword, credentials.IUsernameHashedPassword]
|
||||
got = self.portal.listCredentialsInterfaces()
|
||||
expected.sort()
|
||||
got.sort()
|
||||
self.assertEqual(got, expected)
|
||||
|
||||
def testBasicLogin(self):
|
||||
l = []; f = []
|
||||
self.portal.login(credentials.UsernamePassword("bob", "hello"),
|
||||
self, ITestable).addCallback(
|
||||
l.append).addErrback(f.append)
|
||||
if f:
|
||||
raise f[0]
|
||||
# print l[0].getBriefTraceback()
|
||||
iface, impl, logout = l[0]
|
||||
# whitebox
|
||||
self.assertEqual(iface, ITestable)
|
||||
self.failUnless(iface.providedBy(impl),
|
||||
"%s does not implement %s" % (impl, iface))
|
||||
# greybox
|
||||
self.failUnless(impl.original.loggedIn)
|
||||
self.failUnless(not impl.original.loggedOut)
|
||||
logout()
|
||||
self.failUnless(impl.original.loggedOut)
|
||||
|
||||
def test_derivedInterface(self):
|
||||
"""
|
||||
Login with credentials implementing an interface inheriting from an
|
||||
interface registered with a checker (but not itself registered).
|
||||
"""
|
||||
l = []
|
||||
f = []
|
||||
self.portal.login(DerivedCredentials("bob", "hello"), self, ITestable
|
||||
).addCallback(l.append
|
||||
).addErrback(f.append)
|
||||
if f:
|
||||
raise f[0]
|
||||
iface, impl, logout = l[0]
|
||||
# whitebox
|
||||
self.assertEqual(iface, ITestable)
|
||||
self.failUnless(iface.providedBy(impl),
|
||||
"%s does not implement %s" % (impl, iface))
|
||||
# greybox
|
||||
self.failUnless(impl.original.loggedIn)
|
||||
self.failUnless(not impl.original.loggedOut)
|
||||
logout()
|
||||
self.failUnless(impl.original.loggedOut)
|
||||
|
||||
def testFailedLogin(self):
|
||||
l = []
|
||||
self.portal.login(credentials.UsernamePassword("bob", "h3llo"),
|
||||
self, ITestable).addErrback(
|
||||
lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
|
||||
self.failUnless(l)
|
||||
self.assertEqual(error.UnauthorizedLogin, l[0])
|
||||
|
||||
def testFailedLoginName(self):
|
||||
l = []
|
||||
self.portal.login(credentials.UsernamePassword("jay", "hello"),
|
||||
self, ITestable).addErrback(
|
||||
lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
|
||||
self.failUnless(l)
|
||||
self.assertEqual(error.UnauthorizedLogin, l[0])
|
||||
|
||||
|
||||
class CramMD5CredentialsTestCase(unittest.TestCase):
|
||||
def testIdempotentChallenge(self):
|
||||
c = credentials.CramMD5Credentials()
|
||||
chal = c.getChallenge()
|
||||
self.assertEqual(chal, c.getChallenge())
|
||||
|
||||
def testCheckPassword(self):
|
||||
c = credentials.CramMD5Credentials()
|
||||
chal = c.getChallenge()
|
||||
c.response = hmac.HMAC('secret', chal).hexdigest()
|
||||
self.failUnless(c.checkPassword('secret'))
|
||||
|
||||
def testWrongPassword(self):
|
||||
c = credentials.CramMD5Credentials()
|
||||
self.failIf(c.checkPassword('secret'))
|
||||
|
||||
class OnDiskDatabaseTestCase(unittest.TestCase):
|
||||
users = [
|
||||
('user1', 'pass1'),
|
||||
('user2', 'pass2'),
|
||||
('user3', 'pass3'),
|
||||
]
|
||||
|
||||
|
||||
def testUserLookup(self):
|
||||
dbfile = self.mktemp()
|
||||
db = checkers.FilePasswordDB(dbfile)
|
||||
f = file(dbfile, 'w')
|
||||
for (u, p) in self.users:
|
||||
f.write('%s:%s\n' % (u, p))
|
||||
f.close()
|
||||
|
||||
for (u, p) in self.users:
|
||||
self.failUnlessRaises(KeyError, db.getUser, u.upper())
|
||||
self.assertEqual(db.getUser(u), (u, p))
|
||||
|
||||
def testCaseInSensitivity(self):
|
||||
dbfile = self.mktemp()
|
||||
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
|
||||
f = file(dbfile, 'w')
|
||||
for (u, p) in self.users:
|
||||
f.write('%s:%s\n' % (u, p))
|
||||
f.close()
|
||||
|
||||
for (u, p) in self.users:
|
||||
self.assertEqual(db.getUser(u.upper()), (u, p))
|
||||
|
||||
def testRequestAvatarId(self):
|
||||
dbfile = self.mktemp()
|
||||
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
|
||||
f = file(dbfile, 'w')
|
||||
for (u, p) in self.users:
|
||||
f.write('%s:%s\n' % (u, p))
|
||||
f.close()
|
||||
creds = [credentials.UsernamePassword(u, p) for u, p in self.users]
|
||||
d = defer.gatherResults(
|
||||
[defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
|
||||
d.addCallback(self.assertEqual, [u for u, p in self.users])
|
||||
return d
|
||||
|
||||
def testRequestAvatarId_hashed(self):
|
||||
dbfile = self.mktemp()
|
||||
db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
|
||||
f = file(dbfile, 'w')
|
||||
for (u, p) in self.users:
|
||||
f.write('%s:%s\n' % (u, p))
|
||||
f.close()
|
||||
creds = [credentials.UsernameHashedPassword(u, p) for u, p in self.users]
|
||||
d = defer.gatherResults(
|
||||
[defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
|
||||
d.addCallback(self.assertEqual, [u for u, p in self.users])
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class HashedPasswordOnDiskDatabaseTestCase(unittest.TestCase):
|
||||
users = [
|
||||
('user1', 'pass1'),
|
||||
('user2', 'pass2'),
|
||||
('user3', 'pass3'),
|
||||
]
|
||||
|
||||
|
||||
def hash(self, u, p, s):
|
||||
return crypt(p, s)
|
||||
|
||||
def setUp(self):
|
||||
dbfile = self.mktemp()
|
||||
self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
|
||||
f = file(dbfile, 'w')
|
||||
for (u, p) in self.users:
|
||||
f.write('%s:%s\n' % (u, crypt(p, u[:2])))
|
||||
f.close()
|
||||
r = TestRealm()
|
||||
self.port = portal.Portal(r)
|
||||
self.port.registerChecker(self.db)
|
||||
|
||||
def testGoodCredentials(self):
|
||||
goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
|
||||
d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds])
|
||||
d.addCallback(self.assertEqual, [u for u, p in self.users])
|
||||
return d
|
||||
|
||||
def testGoodCredentials_login(self):
|
||||
goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
|
||||
d = defer.gatherResults([self.port.login(c, None, ITestable)
|
||||
for c in goodCreds])
|
||||
d.addCallback(lambda x: [a.original.name for i, a, l in x])
|
||||
d.addCallback(self.assertEqual, [u for u, p in self.users])
|
||||
return d
|
||||
|
||||
def testBadCredentials(self):
|
||||
badCreds = [credentials.UsernamePassword(u, 'wrong password')
|
||||
for u, p in self.users]
|
||||
d = defer.DeferredList([self.port.login(c, None, ITestable)
|
||||
for c in badCreds], consumeErrors=True)
|
||||
d.addCallback(self._assertFailures, error.UnauthorizedLogin)
|
||||
return d
|
||||
|
||||
def testHashedCredentials(self):
|
||||
hashedCreds = [credentials.UsernameHashedPassword(u, crypt(p, u[:2]))
|
||||
for u, p in self.users]
|
||||
d = defer.DeferredList([self.port.login(c, None, ITestable)
|
||||
for c in hashedCreds], consumeErrors=True)
|
||||
d.addCallback(self._assertFailures, error.UnhandledCredentials)
|
||||
return d
|
||||
|
||||
def _assertFailures(self, failures, *expectedFailures):
|
||||
for flag, failure in failures:
|
||||
self.assertEqual(flag, defer.FAILURE)
|
||||
failure.trap(*expectedFailures)
|
||||
return None
|
||||
|
||||
if crypt is None:
|
||||
skip = "crypt module not available"
|
||||
|
||||
class PluggableAuthenticationModulesTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Replace L{pamauth.callIntoPAM} with a dummy implementation with
|
||||
easily-controlled behavior.
|
||||
"""
|
||||
self.patch(pamauth, 'callIntoPAM', self.callIntoPAM)
|
||||
|
||||
|
||||
def callIntoPAM(self, service, user, conv):
|
||||
if service != 'Twisted':
|
||||
raise error.UnauthorizedLogin('bad service: %s' % service)
|
||||
if user != 'testuser':
|
||||
raise error.UnauthorizedLogin('bad username: %s' % user)
|
||||
questions = [
|
||||
(1, "Password"),
|
||||
(2, "Message w/ Input"),
|
||||
(3, "Message w/o Input"),
|
||||
]
|
||||
replies = conv(questions)
|
||||
if replies != [
|
||||
("password", 0),
|
||||
("entry", 0),
|
||||
("", 0)
|
||||
]:
|
||||
raise error.UnauthorizedLogin('bad conversion: %s' % repr(replies))
|
||||
return 1
|
||||
|
||||
def _makeConv(self, d):
|
||||
def conv(questions):
|
||||
return defer.succeed([(d[t], 0) for t, q in questions])
|
||||
return conv
|
||||
|
||||
def testRequestAvatarId(self):
|
||||
db = checkers.PluggableAuthenticationModulesChecker()
|
||||
conv = self._makeConv({1:'password', 2:'entry', 3:''})
|
||||
creds = credentials.PluggableAuthenticationModules('testuser',
|
||||
conv)
|
||||
d = db.requestAvatarId(creds)
|
||||
d.addCallback(self.assertEqual, 'testuser')
|
||||
return d
|
||||
|
||||
def testBadCredentials(self):
|
||||
db = checkers.PluggableAuthenticationModulesChecker()
|
||||
conv = self._makeConv({1:'', 2:'', 3:''})
|
||||
creds = credentials.PluggableAuthenticationModules('testuser',
|
||||
conv)
|
||||
d = db.requestAvatarId(creds)
|
||||
self.assertFailure(d, error.UnauthorizedLogin)
|
||||
return d
|
||||
|
||||
def testBadUsername(self):
|
||||
db = checkers.PluggableAuthenticationModulesChecker()
|
||||
conv = self._makeConv({1:'password', 2:'entry', 3:''})
|
||||
creds = credentials.PluggableAuthenticationModules('baduser',
|
||||
conv)
|
||||
d = db.requestAvatarId(creds)
|
||||
self.assertFailure(d, error.UnauthorizedLogin)
|
||||
return d
|
||||
|
||||
if not pamauth:
|
||||
skip = "Can't run without PyPAM"
|
||||
|
||||
class CheckersMixin:
|
||||
def testPositive(self):
|
||||
for chk in self.getCheckers():
|
||||
for (cred, avatarId) in self.getGoodCredentials():
|
||||
r = wFD(chk.requestAvatarId(cred))
|
||||
yield r
|
||||
self.assertEqual(r.getResult(), avatarId)
|
||||
testPositive = dG(testPositive)
|
||||
|
||||
def testNegative(self):
|
||||
for chk in self.getCheckers():
|
||||
for cred in self.getBadCredentials():
|
||||
r = wFD(chk.requestAvatarId(cred))
|
||||
yield r
|
||||
self.assertRaises(error.UnauthorizedLogin, r.getResult)
|
||||
testNegative = dG(testNegative)
|
||||
|
||||
class HashlessFilePasswordDBMixin:
|
||||
credClass = credentials.UsernamePassword
|
||||
diskHash = None
|
||||
networkHash = staticmethod(lambda x: x)
|
||||
|
||||
_validCredentials = [
|
||||
('user1', 'password1'),
|
||||
('user2', 'password2'),
|
||||
('user3', 'password3')]
|
||||
|
||||
def getGoodCredentials(self):
|
||||
for u, p in self._validCredentials:
|
||||
yield self.credClass(u, self.networkHash(p)), u
|
||||
|
||||
def getBadCredentials(self):
|
||||
for u, p in [('user1', 'password3'),
|
||||
('user2', 'password1'),
|
||||
('bloof', 'blarf')]:
|
||||
yield self.credClass(u, self.networkHash(p))
|
||||
|
||||
def getCheckers(self):
|
||||
diskHash = self.diskHash or (lambda x: x)
|
||||
hashCheck = self.diskHash and (lambda username, password, stored: self.diskHash(password))
|
||||
|
||||
for cache in True, False:
|
||||
fn = self.mktemp()
|
||||
fObj = file(fn, 'w')
|
||||
for u, p in self._validCredentials:
|
||||
fObj.write('%s:%s\n' % (u, diskHash(p)))
|
||||
fObj.close()
|
||||
yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck)
|
||||
|
||||
fn = self.mktemp()
|
||||
fObj = file(fn, 'w')
|
||||
for u, p in self._validCredentials:
|
||||
fObj.write('%s dingle dongle %s\n' % (diskHash(p), u))
|
||||
fObj.close()
|
||||
yield checkers.FilePasswordDB(fn, ' ', 3, 0, cache=cache, hash=hashCheck)
|
||||
|
||||
fn = self.mktemp()
|
||||
fObj = file(fn, 'w')
|
||||
for u, p in self._validCredentials:
|
||||
fObj.write('zip,zap,%s,zup,%s\n' % (u.title(), diskHash(p)))
|
||||
fObj.close()
|
||||
yield checkers.FilePasswordDB(fn, ',', 2, 4, False, cache=cache, hash=hashCheck)
|
||||
|
||||
class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
|
||||
diskHash = staticmethod(lambda x: x.encode('hex'))
|
||||
|
||||
class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
|
||||
networkHash = staticmethod(lambda x: x.encode('hex'))
|
||||
class credClass(credentials.UsernameHashedPassword):
|
||||
def checkPassword(self, password):
|
||||
return self.hashed.decode('hex') == password
|
||||
|
||||
class HashlessFilePasswordDBCheckerTestCase(HashlessFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
class LocallyHashedFilePasswordDBCheckerTestCase(LocallyHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
class NetworkHashedFilePasswordDBCheckerTestCase(NetworkHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
115
Linux_i686/lib/python2.7/site-packages/twisted/test/test_nmea.py
Normal file
115
Linux_i686/lib/python2.7/site-packages/twisted/test/test_nmea.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""Test cases for the NMEA GPS protocol"""
|
||||
|
||||
import StringIO
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor, protocol
|
||||
from twisted.python import reflect
|
||||
|
||||
from twisted.protocols.gps import nmea
|
||||
|
||||
class StringIOWithNoClose(StringIO.StringIO):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class ResultHarvester:
|
||||
def __init__(self):
|
||||
self.results = []
|
||||
|
||||
def __call__(self, *args):
|
||||
self.results.append(args)
|
||||
|
||||
def performTest(self, function, *args, **kwargs):
|
||||
l = len(self.results)
|
||||
try:
|
||||
function(*args, **kwargs)
|
||||
except Exception, e:
|
||||
self.results.append(e)
|
||||
if l == len(self.results):
|
||||
self.results.append(NotImplementedError())
|
||||
|
||||
class NMEATester(nmea.NMEAReceiver):
|
||||
ignore_invalid_sentence = 0
|
||||
ignore_checksum_mismatch = 0
|
||||
ignore_unknown_sentencetypes = 0
|
||||
convert_dates_before_y2k = 1
|
||||
|
||||
def connectionMade(self):
|
||||
self.resultHarvester = ResultHarvester()
|
||||
for fn in reflect.prefixedMethodNames(self.__class__, 'decode_'):
|
||||
setattr(self, 'handle_' + fn, self.resultHarvester)
|
||||
|
||||
class NMEAReceiverTestCase(unittest.TestCase):
|
||||
messages = (
|
||||
# fix - signal acquired
|
||||
"$GPGGA,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*69",
|
||||
# fix - signal not acquired
|
||||
"$GPGGA,235947.000,0000.0000,N,00000.0000,E,0,00,0.0,0.0,M,,,,0000*00",
|
||||
# junk
|
||||
"lkjasdfkl!@#(*$!@(*#(ASDkfjasdfLMASDCVKAW!@#($)!(@#)(*",
|
||||
# fix - signal acquired (invalid checksum)
|
||||
"$GPGGA,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*68",
|
||||
# invalid sentence
|
||||
"$GPGGX,231713.0,3910.413,N,07641.994,W,1,05,1.35,00044,M,-033,M,,*68",
|
||||
# position acquired
|
||||
"$GPGLL,4250.5589,S,14718.5084,E,092204.999,A*2D",
|
||||
# position not acquired
|
||||
"$GPGLL,0000.0000,N,00000.0000,E,235947.000,V*2D",
|
||||
# active satellites (no fix)
|
||||
"$GPGSA,A,1,,,,,,,,,,,,,0.0,0.0,0.0*30",
|
||||
# active satellites
|
||||
"$GPGSA,A,3,01,20,19,13,,,,,,,,,40.4,24.4,32.2*0A",
|
||||
# positiontime (no fix)
|
||||
"$GPRMC,235947.000,V,0000.0000,N,00000.0000,E,,,041299,,*1D",
|
||||
# positiontime
|
||||
"$GPRMC,092204.999,A,4250.5589,S,14718.5084,E,0.00,89.68,211200,,*25",
|
||||
# course over ground (no fix - not implemented)
|
||||
"$GPVTG,,T,,M,,N,,K*4E",
|
||||
# course over ground (not implemented)
|
||||
"$GPVTG,89.68,T,,M,0.00,N,0.0,K*5F",
|
||||
)
|
||||
results = (
|
||||
(83833.0, 39.17355, -76.6999, nmea.POSFIX_SPS, 5, 1.35, (44.0, 'M'), (-33.0, 'M'), None),
|
||||
(86387.0, 0.0, 0.0, 0, 0, 0.0, (0.0, 'M'), None, None),
|
||||
nmea.InvalidSentence(),
|
||||
nmea.InvalidChecksum(),
|
||||
nmea.InvalidSentence(),
|
||||
(-42.842648333333337, 147.30847333333332, 33724.999000000003, 1),
|
||||
(0.0, 0.0, 86387.0, 0),
|
||||
((None, None, None, None, None, None, None, None, None, None, None, None), (nmea.MODE_AUTO, nmea.MODE_NOFIX), 0.0, 0.0, 0.0),
|
||||
((1, 20, 19, 13, None, None, None, None, None, None, None, None), (nmea.MODE_AUTO, nmea.MODE_3D), 40.4, 24.4, 32.2),
|
||||
(0.0, 0.0, None, None, 86387.0, (1999, 12, 4), None),
|
||||
(-42.842648333333337, 147.30847333333332, 0.0, 89.68, 33724.999, (2000, 12, 21), None),
|
||||
NotImplementedError(),
|
||||
NotImplementedError(),
|
||||
)
|
||||
def testGPSMessages(self):
|
||||
dummy = NMEATester()
|
||||
dummy.makeConnection(protocol.FileWrapper(StringIOWithNoClose()))
|
||||
for line in self.messages:
|
||||
dummy.resultHarvester.performTest(dummy.lineReceived, line)
|
||||
def munge(myTuple):
|
||||
if type(myTuple) != type(()):
|
||||
return
|
||||
newTuple = []
|
||||
for v in myTuple:
|
||||
if type(v) == type(1.1):
|
||||
v = float(int(v * 10000.0)) * 0.0001
|
||||
newTuple.append(v)
|
||||
return tuple(newTuple)
|
||||
for (message, expectedResult, actualResult) in zip(self.messages, self.results, dummy.resultHarvester.results):
|
||||
expectedResult = munge(expectedResult)
|
||||
actualResult = munge(actualResult)
|
||||
if isinstance(expectedResult, Exception):
|
||||
if isinstance(actualResult, Exception):
|
||||
self.assertEqual(expectedResult.__class__, actualResult.__class__, "\nInput:\n%s\nExpected:\n%s.%s\nResults:\n%s.%s\n" % (message, expectedResult.__class__.__module__, expectedResult.__class__.__name__, actualResult.__class__.__module__, actualResult.__class__.__name__))
|
||||
else:
|
||||
self.assertEqual(1, 0, "\nInput:\n%s\nExpected:\n%s.%s\nResults:\n%r\n" % (message, expectedResult.__class__.__module__, expectedResult.__class__.__name__, actualResult))
|
||||
else:
|
||||
self.assertEqual(expectedResult, actualResult, "\nInput:\n%s\nExpected: %r\nResults: %r\n" % (message, expectedResult, actualResult))
|
||||
|
||||
testCases = [NMEAReceiverTestCase]
|
||||
1558
Linux_i686/lib/python2.7/site-packages/twisted/test/test_paths.py
Normal file
1558
Linux_i686/lib/python2.7/site-packages/twisted/test/test_paths.py
Normal file
File diff suppressed because it is too large
Load diff
1846
Linux_i686/lib/python2.7/site-packages/twisted/test/test_pb.py
Normal file
1846
Linux_i686/lib/python2.7/site-packages/twisted/test/test_pb.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,469 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for error handling in PB.
|
||||
"""
|
||||
|
||||
from StringIO import StringIO
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.spread import pb, flavors, jelly
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.python import log
|
||||
|
||||
##
|
||||
# test exceptions
|
||||
##
|
||||
class AsynchronousException(Exception):
|
||||
"""
|
||||
Helper used to test remote methods which return Deferreds which fail with
|
||||
exceptions which are not L{pb.Error} subclasses.
|
||||
"""
|
||||
|
||||
|
||||
class SynchronousException(Exception):
|
||||
"""
|
||||
Helper used to test remote methods which raise exceptions which are not
|
||||
L{pb.Error} subclasses.
|
||||
"""
|
||||
|
||||
|
||||
class AsynchronousError(pb.Error):
|
||||
"""
|
||||
Helper used to test remote methods which return Deferreds which fail with
|
||||
exceptions which are L{pb.Error} subclasses.
|
||||
"""
|
||||
|
||||
|
||||
class SynchronousError(pb.Error):
|
||||
"""
|
||||
Helper used to test remote methods which raise exceptions which are
|
||||
L{pb.Error} subclasses.
|
||||
"""
|
||||
|
||||
|
||||
#class JellyError(flavors.Jellyable, pb.Error): pass
|
||||
class JellyError(flavors.Jellyable, pb.Error, pb.RemoteCopy):
|
||||
pass
|
||||
|
||||
|
||||
class SecurityError(pb.Error, pb.RemoteCopy):
|
||||
pass
|
||||
|
||||
pb.setUnjellyableForClass(JellyError, JellyError)
|
||||
pb.setUnjellyableForClass(SecurityError, SecurityError)
|
||||
pb.globalSecurity.allowInstancesOf(SecurityError)
|
||||
|
||||
|
||||
####
|
||||
# server-side
|
||||
####
|
||||
class SimpleRoot(pb.Root):
|
||||
def remote_asynchronousException(self):
|
||||
"""
|
||||
Fail asynchronously with a non-pb.Error exception.
|
||||
"""
|
||||
return defer.fail(AsynchronousException("remote asynchronous exception"))
|
||||
|
||||
def remote_synchronousException(self):
|
||||
"""
|
||||
Fail synchronously with a non-pb.Error exception.
|
||||
"""
|
||||
raise SynchronousException("remote synchronous exception")
|
||||
|
||||
def remote_asynchronousError(self):
|
||||
"""
|
||||
Fail asynchronously with a pb.Error exception.
|
||||
"""
|
||||
return defer.fail(AsynchronousError("remote asynchronous error"))
|
||||
|
||||
def remote_synchronousError(self):
|
||||
"""
|
||||
Fail synchronously with a pb.Error exception.
|
||||
"""
|
||||
raise SynchronousError("remote synchronous error")
|
||||
|
||||
def remote_unknownError(self):
|
||||
"""
|
||||
Fail with error that is not known to client.
|
||||
"""
|
||||
class UnknownError(pb.Error):
|
||||
pass
|
||||
raise UnknownError("I'm not known to client!")
|
||||
|
||||
def remote_jelly(self):
|
||||
self.raiseJelly()
|
||||
|
||||
def remote_security(self):
|
||||
self.raiseSecurity()
|
||||
|
||||
def remote_deferredJelly(self):
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self.raiseJelly)
|
||||
d.callback(None)
|
||||
return d
|
||||
|
||||
def remote_deferredSecurity(self):
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self.raiseSecurity)
|
||||
d.callback(None)
|
||||
return d
|
||||
|
||||
def raiseJelly(self, results=None):
|
||||
raise JellyError("I'm jellyable!")
|
||||
|
||||
def raiseSecurity(self, results=None):
|
||||
raise SecurityError("I'm secure!")
|
||||
|
||||
|
||||
|
||||
class SaveProtocolServerFactory(pb.PBServerFactory):
|
||||
"""
|
||||
A L{pb.PBServerFactory} that saves the latest connected client in
|
||||
C{protocolInstance}.
|
||||
"""
|
||||
protocolInstance = None
|
||||
|
||||
def clientConnectionMade(self, protocol):
|
||||
"""
|
||||
Keep track of the given protocol.
|
||||
"""
|
||||
self.protocolInstance = protocol
|
||||
|
||||
|
||||
|
||||
class PBConnTestCase(unittest.TestCase):
|
||||
unsafeTracebacks = 0
|
||||
|
||||
def setUp(self):
|
||||
self._setUpServer()
|
||||
self._setUpClient()
|
||||
|
||||
def _setUpServer(self):
|
||||
self.serverFactory = SaveProtocolServerFactory(SimpleRoot())
|
||||
self.serverFactory.unsafeTracebacks = self.unsafeTracebacks
|
||||
self.serverPort = reactor.listenTCP(0, self.serverFactory, interface="127.0.0.1")
|
||||
|
||||
def _setUpClient(self):
|
||||
portNo = self.serverPort.getHost().port
|
||||
self.clientFactory = pb.PBClientFactory()
|
||||
self.clientConnector = reactor.connectTCP("127.0.0.1", portNo, self.clientFactory)
|
||||
|
||||
def tearDown(self):
|
||||
if self.serverFactory.protocolInstance is not None:
|
||||
self.serverFactory.protocolInstance.transport.loseConnection()
|
||||
return defer.gatherResults([
|
||||
self._tearDownServer(),
|
||||
self._tearDownClient()])
|
||||
|
||||
def _tearDownServer(self):
|
||||
return defer.maybeDeferred(self.serverPort.stopListening)
|
||||
|
||||
def _tearDownClient(self):
|
||||
self.clientConnector.disconnect()
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
|
||||
class PBFailureTest(PBConnTestCase):
|
||||
compare = unittest.TestCase.assertEqual
|
||||
|
||||
|
||||
def _exceptionTest(self, method, exceptionType, flush):
|
||||
def eb(err):
|
||||
err.trap(exceptionType)
|
||||
self.compare(err.traceback, "Traceback unavailable\n")
|
||||
if flush:
|
||||
errs = self.flushLoggedErrors(exceptionType)
|
||||
self.assertEqual(len(errs), 1)
|
||||
return (err.type, err.value, err.traceback)
|
||||
d = self.clientFactory.getRootObject()
|
||||
def gotRootObject(root):
|
||||
d = root.callRemote(method)
|
||||
d.addErrback(eb)
|
||||
return d
|
||||
d.addCallback(gotRootObject)
|
||||
return d
|
||||
|
||||
|
||||
def test_asynchronousException(self):
|
||||
"""
|
||||
Test that a Deferred returned by a remote method which already has a
|
||||
Failure correctly has that error passed back to the calling side.
|
||||
"""
|
||||
return self._exceptionTest(
|
||||
'asynchronousException', AsynchronousException, True)
|
||||
|
||||
|
||||
def test_synchronousException(self):
|
||||
"""
|
||||
Like L{test_asynchronousException}, but for a method which raises an
|
||||
exception synchronously.
|
||||
"""
|
||||
return self._exceptionTest(
|
||||
'synchronousException', SynchronousException, True)
|
||||
|
||||
|
||||
def test_asynchronousError(self):
|
||||
"""
|
||||
Like L{test_asynchronousException}, but for a method which returns a
|
||||
Deferred failing with an L{pb.Error} subclass.
|
||||
"""
|
||||
return self._exceptionTest(
|
||||
'asynchronousError', AsynchronousError, False)
|
||||
|
||||
|
||||
def test_synchronousError(self):
|
||||
"""
|
||||
Like L{test_asynchronousError}, but for a method which synchronously
|
||||
raises a L{pb.Error} subclass.
|
||||
"""
|
||||
return self._exceptionTest(
|
||||
'synchronousError', SynchronousError, False)
|
||||
|
||||
|
||||
def _success(self, result, expectedResult):
|
||||
self.assertEqual(result, expectedResult)
|
||||
return result
|
||||
|
||||
|
||||
def _addFailingCallbacks(self, remoteCall, expectedResult, eb):
|
||||
remoteCall.addCallbacks(self._success, eb,
|
||||
callbackArgs=(expectedResult,))
|
||||
return remoteCall
|
||||
|
||||
|
||||
def _testImpl(self, method, expected, eb, exc=None):
|
||||
"""
|
||||
Call the given remote method and attach the given errback to the
|
||||
resulting Deferred. If C{exc} is not None, also assert that one
|
||||
exception of that type was logged.
|
||||
"""
|
||||
rootDeferred = self.clientFactory.getRootObject()
|
||||
def gotRootObj(obj):
|
||||
failureDeferred = self._addFailingCallbacks(obj.callRemote(method), expected, eb)
|
||||
if exc is not None:
|
||||
def gotFailure(err):
|
||||
self.assertEqual(len(self.flushLoggedErrors(exc)), 1)
|
||||
return err
|
||||
failureDeferred.addBoth(gotFailure)
|
||||
return failureDeferred
|
||||
rootDeferred.addCallback(gotRootObj)
|
||||
return rootDeferred
|
||||
|
||||
|
||||
def test_jellyFailure(self):
|
||||
"""
|
||||
Test that an exception which is a subclass of L{pb.Error} has more
|
||||
information passed across the network to the calling side.
|
||||
"""
|
||||
def failureJelly(fail):
|
||||
fail.trap(JellyError)
|
||||
self.failIf(isinstance(fail.type, str))
|
||||
self.failUnless(isinstance(fail.value, fail.type))
|
||||
return 43
|
||||
return self._testImpl('jelly', 43, failureJelly)
|
||||
|
||||
|
||||
def test_deferredJellyFailure(self):
|
||||
"""
|
||||
Test that a Deferred which fails with a L{pb.Error} is treated in
|
||||
the same way as a synchronously raised L{pb.Error}.
|
||||
"""
|
||||
def failureDeferredJelly(fail):
|
||||
fail.trap(JellyError)
|
||||
self.failIf(isinstance(fail.type, str))
|
||||
self.failUnless(isinstance(fail.value, fail.type))
|
||||
return 430
|
||||
return self._testImpl('deferredJelly', 430, failureDeferredJelly)
|
||||
|
||||
|
||||
def test_unjellyableFailure(self):
|
||||
"""
|
||||
An non-jellyable L{pb.Error} subclass raised by a remote method is
|
||||
turned into a Failure with a type set to the FQPN of the exception
|
||||
type.
|
||||
"""
|
||||
def failureUnjellyable(fail):
|
||||
self.assertEqual(
|
||||
fail.type, 'twisted.test.test_pbfailure.SynchronousError')
|
||||
return 431
|
||||
return self._testImpl('synchronousError', 431, failureUnjellyable)
|
||||
|
||||
|
||||
def test_unknownFailure(self):
|
||||
"""
|
||||
Test that an exception which is a subclass of L{pb.Error} but not
|
||||
known on the client side has its type set properly.
|
||||
"""
|
||||
def failureUnknown(fail):
|
||||
self.assertEqual(
|
||||
fail.type, 'twisted.test.test_pbfailure.UnknownError')
|
||||
return 4310
|
||||
return self._testImpl('unknownError', 4310, failureUnknown)
|
||||
|
||||
|
||||
def test_securityFailure(self):
|
||||
"""
|
||||
Test that even if an exception is not explicitly jellyable (by being
|
||||
a L{pb.Jellyable} subclass), as long as it is an L{pb.Error}
|
||||
subclass it receives the same special treatment.
|
||||
"""
|
||||
def failureSecurity(fail):
|
||||
fail.trap(SecurityError)
|
||||
self.failIf(isinstance(fail.type, str))
|
||||
self.failUnless(isinstance(fail.value, fail.type))
|
||||
return 4300
|
||||
return self._testImpl('security', 4300, failureSecurity)
|
||||
|
||||
|
||||
def test_deferredSecurity(self):
|
||||
"""
|
||||
Test that a Deferred which fails with a L{pb.Error} which is not
|
||||
also a L{pb.Jellyable} is treated in the same way as a synchronously
|
||||
raised exception of the same type.
|
||||
"""
|
||||
def failureDeferredSecurity(fail):
|
||||
fail.trap(SecurityError)
|
||||
self.failIf(isinstance(fail.type, str))
|
||||
self.failUnless(isinstance(fail.value, fail.type))
|
||||
return 43000
|
||||
return self._testImpl('deferredSecurity', 43000, failureDeferredSecurity)
|
||||
|
||||
|
||||
def test_noSuchMethodFailure(self):
|
||||
"""
|
||||
Test that attempting to call a method which is not defined correctly
|
||||
results in an AttributeError on the calling side.
|
||||
"""
|
||||
def failureNoSuch(fail):
|
||||
fail.trap(pb.NoSuchMethod)
|
||||
self.compare(fail.traceback, "Traceback unavailable\n")
|
||||
return 42000
|
||||
return self._testImpl('nosuch', 42000, failureNoSuch, AttributeError)
|
||||
|
||||
|
||||
def test_copiedFailureLogging(self):
|
||||
"""
|
||||
Test that a copied failure received from a PB call can be logged
|
||||
locally.
|
||||
|
||||
Note: this test needs some serious help: all it really tests is that
|
||||
log.err(copiedFailure) doesn't raise an exception.
|
||||
"""
|
||||
d = self.clientFactory.getRootObject()
|
||||
|
||||
def connected(rootObj):
|
||||
return rootObj.callRemote('synchronousException')
|
||||
d.addCallback(connected)
|
||||
|
||||
def exception(failure):
|
||||
log.err(failure)
|
||||
errs = self.flushLoggedErrors(SynchronousException)
|
||||
self.assertEqual(len(errs), 2)
|
||||
d.addErrback(exception)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def test_throwExceptionIntoGenerator(self):
|
||||
"""
|
||||
L{pb.CopiedFailure.throwExceptionIntoGenerator} will throw a
|
||||
L{RemoteError} into the given paused generator at the point where it
|
||||
last yielded.
|
||||
"""
|
||||
original = pb.CopyableFailure(AttributeError("foo"))
|
||||
copy = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
|
||||
exception = []
|
||||
def generatorFunc():
|
||||
try:
|
||||
yield None
|
||||
except pb.RemoteError, exc:
|
||||
exception.append(exc)
|
||||
else:
|
||||
self.fail("RemoteError not raised")
|
||||
gen = generatorFunc()
|
||||
gen.send(None)
|
||||
self.assertRaises(StopIteration, copy.throwExceptionIntoGenerator, gen)
|
||||
self.assertEqual(len(exception), 1)
|
||||
exc = exception[0]
|
||||
self.assertEqual(exc.remoteType, "exceptions.AttributeError")
|
||||
self.assertEqual(exc.args, ("foo",))
|
||||
self.assertEqual(exc.remoteTraceback, 'Traceback unavailable\n')
|
||||
|
||||
|
||||
|
||||
class PBFailureTestUnsafe(PBFailureTest):
|
||||
compare = unittest.TestCase.failIfEquals
|
||||
unsafeTracebacks = 1
|
||||
|
||||
|
||||
|
||||
class DummyInvoker(object):
|
||||
"""
|
||||
A behaviorless object to be used as the invoker parameter to
|
||||
L{jelly.jelly}.
|
||||
"""
|
||||
serializingPerspective = None
|
||||
|
||||
|
||||
|
||||
class FailureJellyingTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the interaction of jelly and failures.
|
||||
"""
|
||||
def test_unjelliedFailureCheck(self):
|
||||
"""
|
||||
An unjellied L{CopyableFailure} has a check method which behaves the
|
||||
same way as the original L{CopyableFailure}'s check method.
|
||||
"""
|
||||
original = pb.CopyableFailure(ZeroDivisionError())
|
||||
self.assertIdentical(
|
||||
original.check(ZeroDivisionError), ZeroDivisionError)
|
||||
self.assertIdentical(original.check(ArithmeticError), ArithmeticError)
|
||||
copied = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
|
||||
self.assertIdentical(
|
||||
copied.check(ZeroDivisionError), ZeroDivisionError)
|
||||
self.assertIdentical(copied.check(ArithmeticError), ArithmeticError)
|
||||
|
||||
|
||||
def test_twiceUnjelliedFailureCheck(self):
|
||||
"""
|
||||
The object which results from jellying a L{CopyableFailure}, unjellying
|
||||
the result, creating a new L{CopyableFailure} from the result of that,
|
||||
jellying it, and finally unjellying the result of that has a check
|
||||
method which behaves the same way as the original L{CopyableFailure}'s
|
||||
check method.
|
||||
"""
|
||||
original = pb.CopyableFailure(ZeroDivisionError())
|
||||
self.assertIdentical(
|
||||
original.check(ZeroDivisionError), ZeroDivisionError)
|
||||
self.assertIdentical(original.check(ArithmeticError), ArithmeticError)
|
||||
copiedOnce = jelly.unjelly(
|
||||
jelly.jelly(original, invoker=DummyInvoker()))
|
||||
derivative = pb.CopyableFailure(copiedOnce)
|
||||
copiedTwice = jelly.unjelly(
|
||||
jelly.jelly(derivative, invoker=DummyInvoker()))
|
||||
self.assertIdentical(
|
||||
copiedTwice.check(ZeroDivisionError), ZeroDivisionError)
|
||||
self.assertIdentical(
|
||||
copiedTwice.check(ArithmeticError), ArithmeticError)
|
||||
|
||||
|
||||
def test_printTracebackIncludesValue(self):
|
||||
"""
|
||||
When L{CopiedFailure.printTraceback} is used to print a copied failure
|
||||
which was unjellied from a L{CopyableFailure} with C{unsafeTracebacks}
|
||||
set to C{False}, the string representation of the exception value is
|
||||
included in the output.
|
||||
"""
|
||||
original = pb.CopyableFailure(Exception("some reason"))
|
||||
copied = jelly.unjelly(jelly.jelly(original, invoker=DummyInvoker()))
|
||||
output = StringIO()
|
||||
copied.printTraceback(output)
|
||||
self.assertEqual(
|
||||
"Traceback from remote host -- Traceback unavailable\n"
|
||||
"exceptions.Exception: some reason\n",
|
||||
output.getvalue())
|
||||
|
||||
368
Linux_i686/lib/python2.7/site-packages/twisted/test/test_pcp.py
Normal file
368
Linux_i686/lib/python2.7/site-packages/twisted/test/test_pcp.py
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
# -*- Python -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
__version__ = '$Revision: 1.5 $'[11:-2]
|
||||
|
||||
from StringIO import StringIO
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import pcp
|
||||
|
||||
# Goal:
|
||||
|
||||
# Take a Protocol instance. Own all outgoing data - anything that
|
||||
# would go to p.transport.write. Own all incoming data - anything
|
||||
# that comes to p.dataReceived.
|
||||
|
||||
# I need:
|
||||
# Something with the AbstractFileDescriptor interface.
|
||||
# That is:
|
||||
# - acts as a Transport
|
||||
# - has a method write()
|
||||
# - which buffers
|
||||
# - acts as a Consumer
|
||||
# - has a registerProducer, unRegisterProducer
|
||||
# - tells the Producer to back off (pauseProducing) when its buffer is full.
|
||||
# - tells the Producer to resumeProducing when its buffer is not so full.
|
||||
# - acts as a Producer
|
||||
# - calls registerProducer
|
||||
# - calls write() on consumers
|
||||
# - honors requests to pause/resume producing
|
||||
# - honors stopProducing, and passes it along to upstream Producers
|
||||
|
||||
|
||||
class DummyTransport:
|
||||
"""A dumb transport to wrap around."""
|
||||
|
||||
def __init__(self):
|
||||
self._writes = []
|
||||
|
||||
def write(self, data):
|
||||
self._writes.append(data)
|
||||
|
||||
def getvalue(self):
|
||||
return ''.join(self._writes)
|
||||
|
||||
class DummyProducer:
|
||||
resumed = False
|
||||
stopped = False
|
||||
paused = False
|
||||
|
||||
def __init__(self, consumer):
|
||||
self.consumer = consumer
|
||||
|
||||
def resumeProducing(self):
|
||||
self.resumed = True
|
||||
self.paused = False
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
|
||||
def stopProducing(self):
|
||||
self.stopped = True
|
||||
|
||||
|
||||
class DummyConsumer(DummyTransport):
|
||||
producer = None
|
||||
finished = False
|
||||
unregistered = True
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = (producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.unregistered = True
|
||||
|
||||
def finish(self):
|
||||
self.finished = True
|
||||
|
||||
class TransportInterfaceTest(unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.transport = self.proxyClass(self.underlying)
|
||||
|
||||
def testWrite(self):
|
||||
self.transport.write("some bytes")
|
||||
|
||||
class ConsumerInterfaceTest:
|
||||
"""Test ProducerConsumerProxy as a Consumer.
|
||||
|
||||
Normally we have ProducingServer -> ConsumingTransport.
|
||||
|
||||
If I am to go between (Server -> Shaper -> Transport), I have to
|
||||
play the role of Consumer convincingly for the ProducingServer.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.consumer = self.proxyClass(self.underlying)
|
||||
self.producer = DummyProducer(self.consumer)
|
||||
|
||||
def testRegisterPush(self):
|
||||
self.consumer.registerProducer(self.producer, True)
|
||||
## Consumer should NOT have called PushProducer.resumeProducing
|
||||
self.failIf(self.producer.resumed)
|
||||
|
||||
## I'm I'm just a proxy, should I only do resumeProducing when
|
||||
## I get poked myself?
|
||||
#def testRegisterPull(self):
|
||||
# self.consumer.registerProducer(self.producer, False)
|
||||
# ## Consumer SHOULD have called PushProducer.resumeProducing
|
||||
# self.failUnless(self.producer.resumed)
|
||||
|
||||
def testUnregister(self):
|
||||
self.consumer.registerProducer(self.producer, False)
|
||||
self.consumer.unregisterProducer()
|
||||
# Now when the consumer would ordinarily want more data, it
|
||||
# shouldn't ask producer for it.
|
||||
# The most succinct way to trigger "want more data" is to proxy for
|
||||
# a PullProducer and have someone ask me for data.
|
||||
self.producer.resumed = False
|
||||
self.consumer.resumeProducing()
|
||||
self.failIf(self.producer.resumed)
|
||||
|
||||
def testFinish(self):
|
||||
self.consumer.registerProducer(self.producer, False)
|
||||
self.consumer.finish()
|
||||
# I guess finish should behave like unregister?
|
||||
self.producer.resumed = False
|
||||
self.consumer.resumeProducing()
|
||||
self.failIf(self.producer.resumed)
|
||||
|
||||
|
||||
class ProducerInterfaceTest:
|
||||
"""Test ProducerConsumerProxy as a Producer.
|
||||
|
||||
Normally we have ProducingServer -> ConsumingTransport.
|
||||
|
||||
If I am to go between (Server -> Shaper -> Transport), I have to
|
||||
play the role of Producer convincingly for the ConsumingTransport.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.consumer = DummyConsumer()
|
||||
self.producer = self.proxyClass(self.consumer)
|
||||
|
||||
def testRegistersProducer(self):
|
||||
self.assertEqual(self.consumer.producer[0], self.producer)
|
||||
|
||||
def testPause(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.failIf(self.consumer.getvalue(),
|
||||
"Paused producer should not have sent data.")
|
||||
|
||||
def testResume(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.resumeProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.assertEqual(self.consumer.getvalue(), "yakkity yak")
|
||||
|
||||
def testResumeNoEmptyWrite(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.resumeProducing()
|
||||
self.assertEqual(len(self.consumer._writes), 0,
|
||||
"Resume triggered an empty write.")
|
||||
|
||||
def testResumeBuffer(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.write("buffer this")
|
||||
self.producer.resumeProducing()
|
||||
self.assertEqual(self.consumer.getvalue(), "buffer this")
|
||||
|
||||
def testStop(self):
|
||||
self.producer.stopProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.failIf(self.consumer.getvalue(),
|
||||
"Stopped producer should not have sent data.")
|
||||
|
||||
|
||||
class PCP_ConsumerInterfaceTest(ConsumerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
class PCPII_ConsumerInterfaceTest(ConsumerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
class PCP_ProducerInterfaceTest(ProducerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
class PCPII_ProducerInterfaceTest(ProducerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
class ProducerProxyTest(unittest.TestCase):
|
||||
"""Producer methods on me should be relayed to the Producer I proxy.
|
||||
"""
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.proxy = self.proxyClass(None)
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testStop(self):
|
||||
self.proxy.stopProducing()
|
||||
self.failUnless(self.parentProducer.stopped)
|
||||
|
||||
|
||||
class ConsumerProxyTest(unittest.TestCase):
|
||||
"""Consumer methods on me should be relayed to the Consumer I proxy.
|
||||
"""
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.consumer = self.proxyClass(self.underlying)
|
||||
|
||||
def testWrite(self):
|
||||
# NOTE: This test only valid for streaming (Push) systems.
|
||||
self.consumer.write("some bytes")
|
||||
self.assertEqual(self.underlying.getvalue(), "some bytes")
|
||||
|
||||
def testFinish(self):
|
||||
self.consumer.finish()
|
||||
self.failUnless(self.underlying.finished)
|
||||
|
||||
def testUnregister(self):
|
||||
self.consumer.unregisterProducer()
|
||||
self.failUnless(self.underlying.unregistered)
|
||||
|
||||
|
||||
class PullProducerTest:
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testHoldWrites(self):
|
||||
self.proxy.write("hello")
|
||||
# Consumer should get no data before it says resumeProducing.
|
||||
self.failIf(self.underlying.getvalue(),
|
||||
"Pulling Consumer got data before it pulled.")
|
||||
|
||||
def testPull(self):
|
||||
self.proxy.write("hello")
|
||||
self.proxy.resumeProducing()
|
||||
self.assertEqual(self.underlying.getvalue(), "hello")
|
||||
|
||||
def testMergeWrites(self):
|
||||
self.proxy.write("hello ")
|
||||
self.proxy.write("sunshine")
|
||||
self.proxy.resumeProducing()
|
||||
nwrites = len(self.underlying._writes)
|
||||
self.assertEqual(nwrites, 1, "Pull resulted in %d writes instead "
|
||||
"of 1." % (nwrites,))
|
||||
self.assertEqual(self.underlying.getvalue(), "hello sunshine")
|
||||
|
||||
|
||||
def testLateWrite(self):
|
||||
# consumer sends its initial pull before we have data
|
||||
self.proxy.resumeProducing()
|
||||
self.proxy.write("data")
|
||||
# This data should answer that pull request.
|
||||
self.assertEqual(self.underlying.getvalue(), "data")
|
||||
|
||||
class PCP_PullProducerTest(PullProducerTest, unittest.TestCase):
|
||||
class proxyClass(pcp.BasicProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
class PCPII_PullProducerTest(PullProducerTest, unittest.TestCase):
|
||||
class proxyClass(pcp.ProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
# Buffering!
|
||||
|
||||
class BufferedConsumerTest(unittest.TestCase):
|
||||
"""As a consumer, ask the producer to pause after too much data."""
|
||||
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.proxy.bufferSize = 100
|
||||
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testRegisterPull(self):
|
||||
self.proxy.registerProducer(self.parentProducer, False)
|
||||
## Consumer SHOULD have called PushProducer.resumeProducing
|
||||
self.failUnless(self.parentProducer.resumed)
|
||||
|
||||
def testPauseIntercept(self):
|
||||
self.proxy.pauseProducing()
|
||||
self.failIf(self.parentProducer.paused)
|
||||
|
||||
def testResumeIntercept(self):
|
||||
self.proxy.pauseProducing()
|
||||
self.proxy.resumeProducing()
|
||||
# With a streaming producer, just because the proxy was resumed is
|
||||
# not necessarily a reason to resume the parent producer. The state
|
||||
# of the buffer should decide that.
|
||||
self.failIf(self.parentProducer.resumed)
|
||||
|
||||
def testTriggerPause(self):
|
||||
"""Make sure I say \"when.\""""
|
||||
|
||||
# Pause the proxy so data sent to it builds up in its buffer.
|
||||
self.proxy.pauseProducing()
|
||||
self.failIf(self.parentProducer.paused, "don't pause yet")
|
||||
self.proxy.write("x" * 51)
|
||||
self.failIf(self.parentProducer.paused, "don't pause yet")
|
||||
self.proxy.write("x" * 51)
|
||||
self.failUnless(self.parentProducer.paused)
|
||||
|
||||
def testTriggerResume(self):
|
||||
"""Make sure I resumeProducing when my buffer empties."""
|
||||
self.proxy.pauseProducing()
|
||||
self.proxy.write("x" * 102)
|
||||
self.failUnless(self.parentProducer.paused, "should be paused")
|
||||
self.proxy.resumeProducing()
|
||||
# Resuming should have emptied my buffer, so I should tell my
|
||||
# parent to resume too.
|
||||
self.failIf(self.parentProducer.paused,
|
||||
"Producer should have resumed.")
|
||||
self.failIf(self.proxy.producerPaused)
|
||||
|
||||
class BufferedPullTests(unittest.TestCase):
|
||||
class proxyClass(pcp.ProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
pcp.ProducerConsumerProxy._writeSomeData(self, data[:100])
|
||||
return min(len(data), 100)
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.proxy.bufferSize = 100
|
||||
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, False)
|
||||
|
||||
def testResumePull(self):
|
||||
# If proxy has no data to send on resumeProducing, it had better pull
|
||||
# some from its PullProducer.
|
||||
self.parentProducer.resumed = False
|
||||
self.proxy.resumeProducing()
|
||||
self.failUnless(self.parentProducer.resumed)
|
||||
|
||||
def testLateWriteBuffering(self):
|
||||
# consumer sends its initial pull before we have data
|
||||
self.proxy.resumeProducing()
|
||||
self.proxy.write("datum" * 21)
|
||||
# This data should answer that pull request.
|
||||
self.assertEqual(self.underlying.getvalue(), "datum" * 20)
|
||||
# but there should be some left over
|
||||
self.assertEqual(self.proxy._buffer, ["datum"])
|
||||
|
||||
|
||||
# TODO:
|
||||
# test that web request finishing bug (when we weren't proxying
|
||||
# unregisterProducer but were proxying finish, web file transfers
|
||||
# would hang on the last block.)
|
||||
# test what happens if writeSomeBytes decided to write zero bytes.
|
||||
|
|
@ -0,0 +1,377 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
# System Imports
|
||||
import sys
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
try:
|
||||
import cStringIO as StringIO
|
||||
except ImportError:
|
||||
import StringIO
|
||||
|
||||
# Twisted Imports
|
||||
from twisted.persisted import styles, aot, crefutil
|
||||
|
||||
|
||||
class VersionTestCase(unittest.TestCase):
|
||||
def testNullVersionUpgrade(self):
|
||||
global NullVersioned
|
||||
class NullVersioned:
|
||||
ok = 0
|
||||
pkcl = pickle.dumps(NullVersioned())
|
||||
class NullVersioned(styles.Versioned):
|
||||
persistenceVersion = 1
|
||||
def upgradeToVersion1(self):
|
||||
self.ok = 1
|
||||
mnv = pickle.loads(pkcl)
|
||||
styles.doUpgrade()
|
||||
assert mnv.ok, "initial upgrade not run!"
|
||||
|
||||
def testVersionUpgrade(self):
|
||||
global MyVersioned
|
||||
class MyVersioned(styles.Versioned):
|
||||
persistenceVersion = 2
|
||||
persistenceForgets = ['garbagedata']
|
||||
v3 = 0
|
||||
v4 = 0
|
||||
|
||||
def __init__(self):
|
||||
self.somedata = 'xxx'
|
||||
self.garbagedata = lambda q: 'cant persist'
|
||||
|
||||
def upgradeToVersion3(self):
|
||||
self.v3 += 1
|
||||
|
||||
def upgradeToVersion4(self):
|
||||
self.v4 += 1
|
||||
mv = MyVersioned()
|
||||
assert not (mv.v3 or mv.v4), "hasn't been upgraded yet"
|
||||
pickl = pickle.dumps(mv)
|
||||
MyVersioned.persistenceVersion = 4
|
||||
obj = pickle.loads(pickl)
|
||||
styles.doUpgrade()
|
||||
assert obj.v3, "didn't do version 3 upgrade"
|
||||
assert obj.v4, "didn't do version 4 upgrade"
|
||||
pickl = pickle.dumps(obj)
|
||||
obj = pickle.loads(pickl)
|
||||
styles.doUpgrade()
|
||||
assert obj.v3 == 1, "upgraded unnecessarily"
|
||||
assert obj.v4 == 1, "upgraded unnecessarily"
|
||||
|
||||
def testNonIdentityHash(self):
|
||||
global ClassWithCustomHash
|
||||
class ClassWithCustomHash(styles.Versioned):
|
||||
def __init__(self, unique, hash):
|
||||
self.unique = unique
|
||||
self.hash = hash
|
||||
def __hash__(self):
|
||||
return self.hash
|
||||
|
||||
v1 = ClassWithCustomHash('v1', 0)
|
||||
v2 = ClassWithCustomHash('v2', 0)
|
||||
|
||||
pkl = pickle.dumps((v1, v2))
|
||||
del v1, v2
|
||||
ClassWithCustomHash.persistenceVersion = 1
|
||||
ClassWithCustomHash.upgradeToVersion1 = lambda self: setattr(self, 'upgraded', True)
|
||||
v1, v2 = pickle.loads(pkl)
|
||||
styles.doUpgrade()
|
||||
self.assertEqual(v1.unique, 'v1')
|
||||
self.assertEqual(v2.unique, 'v2')
|
||||
self.failUnless(v1.upgraded)
|
||||
self.failUnless(v2.upgraded)
|
||||
|
||||
def testUpgradeDeserializesObjectsRequiringUpgrade(self):
|
||||
global ToyClassA, ToyClassB
|
||||
class ToyClassA(styles.Versioned):
|
||||
pass
|
||||
class ToyClassB(styles.Versioned):
|
||||
pass
|
||||
x = ToyClassA()
|
||||
y = ToyClassB()
|
||||
pklA, pklB = pickle.dumps(x), pickle.dumps(y)
|
||||
del x, y
|
||||
ToyClassA.persistenceVersion = 1
|
||||
def upgradeToVersion1(self):
|
||||
self.y = pickle.loads(pklB)
|
||||
styles.doUpgrade()
|
||||
ToyClassA.upgradeToVersion1 = upgradeToVersion1
|
||||
ToyClassB.persistenceVersion = 1
|
||||
ToyClassB.upgradeToVersion1 = lambda self: setattr(self, 'upgraded', True)
|
||||
|
||||
x = pickle.loads(pklA)
|
||||
styles.doUpgrade()
|
||||
self.failUnless(x.y.upgraded)
|
||||
|
||||
|
||||
|
||||
class VersionedSubClass(styles.Versioned):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class SecondVersionedSubClass(styles.Versioned):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class VersionedSubSubClass(VersionedSubClass):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class VersionedDiamondSubClass(VersionedSubSubClass, SecondVersionedSubClass):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class AybabtuTests(unittest.TestCase):
|
||||
"""
|
||||
L{styles._aybabtu} gets all of classes in the inheritance hierarchy of its
|
||||
argument that are strictly between L{Versioned} and the class itself.
|
||||
"""
|
||||
|
||||
def test_aybabtuStrictEmpty(self):
|
||||
"""
|
||||
L{styles._aybabtu} of L{Versioned} itself is an empty list.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(styles.Versioned), [])
|
||||
|
||||
|
||||
def test_aybabtuStrictSubclass(self):
|
||||
"""
|
||||
There are no classes I{between} L{VersionedSubClass} and L{Versioned},
|
||||
so L{styles._aybabtu} returns an empty list.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(VersionedSubClass), [])
|
||||
|
||||
|
||||
def test_aybabtuSubsubclass(self):
|
||||
"""
|
||||
With a sub-sub-class of L{Versioned}, L{styles._aybabtu} returns a list
|
||||
containing the intervening subclass.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(VersionedSubSubClass),
|
||||
[VersionedSubClass])
|
||||
|
||||
|
||||
def test_aybabtuStrict(self):
|
||||
"""
|
||||
For a diamond-shaped inheritance graph, L{styles._aybabtu} returns a
|
||||
list containing I{both} intermediate subclasses.
|
||||
"""
|
||||
self.assertEqual(
|
||||
styles._aybabtu(VersionedDiamondSubClass),
|
||||
[VersionedSubSubClass, VersionedSubClass, SecondVersionedSubClass])
|
||||
|
||||
|
||||
|
||||
class MyEphemeral(styles.Ephemeral):
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
|
||||
class EphemeralTestCase(unittest.TestCase):
|
||||
|
||||
def testEphemeral(self):
|
||||
o = MyEphemeral(3)
|
||||
self.assertEqual(o.__class__, MyEphemeral)
|
||||
self.assertEqual(o.x, 3)
|
||||
|
||||
pickl = pickle.dumps(o)
|
||||
o = pickle.loads(pickl)
|
||||
|
||||
self.assertEqual(o.__class__, styles.Ephemeral)
|
||||
self.assert_(not hasattr(o, 'x'))
|
||||
|
||||
|
||||
class Pickleable:
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def getX(self):
|
||||
return self.x
|
||||
|
||||
class A:
|
||||
"""
|
||||
dummy class
|
||||
"""
|
||||
def amethod(self):
|
||||
pass
|
||||
|
||||
class B:
|
||||
"""
|
||||
dummy class
|
||||
"""
|
||||
def bmethod(self):
|
||||
pass
|
||||
|
||||
def funktion():
|
||||
pass
|
||||
|
||||
class PicklingTestCase(unittest.TestCase):
|
||||
"""Test pickling of extra object types."""
|
||||
|
||||
def testModule(self):
|
||||
pickl = pickle.dumps(styles)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(o, styles)
|
||||
|
||||
def testClassMethod(self):
|
||||
pickl = pickle.dumps(Pickleable.getX)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(o, Pickleable.getX)
|
||||
|
||||
def testInstanceMethod(self):
|
||||
obj = Pickleable(4)
|
||||
pickl = pickle.dumps(obj.getX)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(o(), 4)
|
||||
self.assertEqual(type(o), type(obj.getX))
|
||||
|
||||
def testStringIO(self):
|
||||
f = StringIO.StringIO()
|
||||
f.write("abc")
|
||||
pickl = pickle.dumps(f)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(type(o), type(f))
|
||||
self.assertEqual(f.getvalue(), "abc")
|
||||
|
||||
|
||||
class EvilSourceror:
|
||||
def __init__(self, x):
|
||||
self.a = self
|
||||
self.a.b = self
|
||||
self.a.b.c = x
|
||||
|
||||
class NonDictState:
|
||||
def __getstate__(self):
|
||||
return self.state
|
||||
def __setstate__(self, state):
|
||||
self.state = state
|
||||
|
||||
class AOTTestCase(unittest.TestCase):
|
||||
def testSimpleTypes(self):
|
||||
obj = (1, 2.0, 3j, True, slice(1, 2, 3), 'hello', u'world', sys.maxint + 1, None, Ellipsis)
|
||||
rtObj = aot.unjellyFromSource(aot.jellyToSource(obj))
|
||||
self.assertEqual(obj, rtObj)
|
||||
|
||||
def testMethodSelfIdentity(self):
|
||||
a = A()
|
||||
b = B()
|
||||
a.bmethod = b.bmethod
|
||||
b.a = a
|
||||
im_ = aot.unjellyFromSource(aot.jellyToSource(b)).a.bmethod
|
||||
self.assertEqual(im_.im_class, im_.im_self.__class__)
|
||||
|
||||
|
||||
def test_methodNotSelfIdentity(self):
|
||||
"""
|
||||
If a class change after an instance has been created,
|
||||
L{aot.unjellyFromSource} shoud raise a C{TypeError} when trying to
|
||||
unjelly the instance.
|
||||
"""
|
||||
a = A()
|
||||
b = B()
|
||||
a.bmethod = b.bmethod
|
||||
b.a = a
|
||||
savedbmethod = B.bmethod
|
||||
del B.bmethod
|
||||
try:
|
||||
self.assertRaises(TypeError, aot.unjellyFromSource,
|
||||
aot.jellyToSource(b))
|
||||
finally:
|
||||
B.bmethod = savedbmethod
|
||||
|
||||
|
||||
def test_unsupportedType(self):
|
||||
"""
|
||||
L{aot.jellyToSource} should raise a C{TypeError} when trying to jelly
|
||||
an unknown type.
|
||||
"""
|
||||
try:
|
||||
set
|
||||
except:
|
||||
from sets import Set as set
|
||||
self.assertRaises(TypeError, aot.jellyToSource, set())
|
||||
|
||||
|
||||
def testBasicIdentity(self):
|
||||
# Anyone wanting to make this datastructure more complex, and thus this
|
||||
# test more comprehensive, is welcome to do so.
|
||||
aj = aot.AOTJellier().jellyToAO
|
||||
d = {'hello': 'world', "method": aj}
|
||||
l = [1, 2, 3,
|
||||
"he\tllo\n\n\"x world!",
|
||||
u"goodbye \n\t\u1010 world!",
|
||||
1, 1.0, 100 ** 100l, unittest, aot.AOTJellier, d,
|
||||
funktion
|
||||
]
|
||||
t = tuple(l)
|
||||
l.append(l)
|
||||
l.append(t)
|
||||
l.append(t)
|
||||
uj = aot.unjellyFromSource(aot.jellyToSource([l, l]))
|
||||
assert uj[0] is uj[1]
|
||||
assert uj[1][0:5] == l[0:5]
|
||||
|
||||
|
||||
def testNonDictState(self):
|
||||
a = NonDictState()
|
||||
a.state = "meringue!"
|
||||
assert aot.unjellyFromSource(aot.jellyToSource(a)).state == a.state
|
||||
|
||||
def testCopyReg(self):
|
||||
s = "foo_bar"
|
||||
sio = StringIO.StringIO()
|
||||
sio.write(s)
|
||||
uj = aot.unjellyFromSource(aot.jellyToSource(sio))
|
||||
# print repr(uj.__dict__)
|
||||
assert uj.getvalue() == s
|
||||
|
||||
def testFunkyReferences(self):
|
||||
o = EvilSourceror(EvilSourceror([]))
|
||||
j1 = aot.jellyToAOT(o)
|
||||
oj = aot.unjellyFromAOT(j1)
|
||||
|
||||
assert oj.a is oj
|
||||
assert oj.a.b is oj.b
|
||||
assert oj.c is not oj.c.c
|
||||
|
||||
|
||||
class CrefUtilTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{crefutil}.
|
||||
"""
|
||||
|
||||
def test_dictUnknownKey(self):
|
||||
"""
|
||||
L{crefutil._DictKeyAndValue} only support keys C{0} and C{1}.
|
||||
"""
|
||||
d = crefutil._DictKeyAndValue({})
|
||||
self.assertRaises(RuntimeError, d.__setitem__, 2, 3)
|
||||
|
||||
|
||||
def test_deferSetMultipleTimes(self):
|
||||
"""
|
||||
L{crefutil._Defer} can be assigned a key only one time.
|
||||
"""
|
||||
d = crefutil._Defer()
|
||||
d[0] = 1
|
||||
self.assertRaises(RuntimeError, d.__setitem__, 0, 1)
|
||||
|
||||
|
||||
|
||||
testCases = [VersionTestCase, EphemeralTestCase, PicklingTestCase]
|
||||
|
||||
|
|
@ -0,0 +1,719 @@
|
|||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for Twisted plugin system.
|
||||
"""
|
||||
|
||||
import sys, errno, os, time
|
||||
import compileall
|
||||
|
||||
from zope.interface import Interface
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python.log import textFromEventDict, addObserver, removeObserver
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.util import mergeFunctionMetadata
|
||||
|
||||
from twisted import plugin
|
||||
|
||||
|
||||
|
||||
class ITestPlugin(Interface):
|
||||
"""
|
||||
A plugin for use by the plugin system's unit tests.
|
||||
|
||||
Do not use this.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ITestPlugin2(Interface):
|
||||
"""
|
||||
See L{ITestPlugin}.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class PluginTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests which verify the behavior of the current, active Twisted plugins
|
||||
directory.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Save C{sys.path} and C{sys.modules}, and create a package for tests.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
|
||||
self.root = FilePath(self.mktemp())
|
||||
self.root.createDirectory()
|
||||
self.package = self.root.child('mypackage')
|
||||
self.package.createDirectory()
|
||||
self.package.child('__init__.py').setContent("")
|
||||
|
||||
FilePath(__file__).sibling('plugin_basic.py'
|
||||
).copyTo(self.package.child('testplugin.py'))
|
||||
|
||||
self.originalPlugin = "testplugin"
|
||||
|
||||
sys.path.insert(0, self.root.path)
|
||||
import mypackage
|
||||
self.module = mypackage
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore C{sys.path} and C{sys.modules} to their original values.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
|
||||
|
||||
def _unimportPythonModule(self, module, deleteSource=False):
|
||||
modulePath = module.__name__.split('.')
|
||||
packageName = '.'.join(modulePath[:-1])
|
||||
moduleName = modulePath[-1]
|
||||
|
||||
delattr(sys.modules[packageName], moduleName)
|
||||
del sys.modules[module.__name__]
|
||||
for ext in ['c', 'o'] + (deleteSource and [''] or []):
|
||||
try:
|
||||
os.remove(module.__file__ + ext)
|
||||
except OSError, ose:
|
||||
if ose.errno != errno.ENOENT:
|
||||
raise
|
||||
|
||||
|
||||
def _clearCache(self):
|
||||
"""
|
||||
Remove the plugins B{droping.cache} file.
|
||||
"""
|
||||
self.package.child('dropin.cache').remove()
|
||||
|
||||
|
||||
def _withCacheness(meth):
|
||||
"""
|
||||
This is a paranoid test wrapper, that calls C{meth} 2 times, clear the
|
||||
cache, and calls it 2 other times. It's supposed to ensure that the
|
||||
plugin system behaves correctly no matter what the state of the cache
|
||||
is.
|
||||
"""
|
||||
def wrapped(self):
|
||||
meth(self)
|
||||
meth(self)
|
||||
self._clearCache()
|
||||
meth(self)
|
||||
meth(self)
|
||||
return mergeFunctionMetadata(meth, wrapped)
|
||||
|
||||
|
||||
def test_cache(self):
|
||||
"""
|
||||
Check that the cache returned by L{plugin.getCache} hold the plugin
|
||||
B{testplugin}, and that this plugin has the properties we expect:
|
||||
provide L{TestPlugin}, has the good name and description, and can be
|
||||
loaded successfully.
|
||||
"""
|
||||
cache = plugin.getCache(self.module)
|
||||
|
||||
dropin = cache[self.originalPlugin]
|
||||
self.assertEqual(dropin.moduleName,
|
||||
'mypackage.%s' % (self.originalPlugin,))
|
||||
self.assertIn("I'm a test drop-in.", dropin.description)
|
||||
|
||||
# Note, not the preferred way to get a plugin by its interface.
|
||||
p1 = [p for p in dropin.plugins if ITestPlugin in p.provided][0]
|
||||
self.assertIdentical(p1.dropin, dropin)
|
||||
self.assertEqual(p1.name, "TestPlugin")
|
||||
|
||||
# Check the content of the description comes from the plugin module
|
||||
# docstring
|
||||
self.assertEqual(
|
||||
p1.description.strip(),
|
||||
"A plugin used solely for testing purposes.")
|
||||
self.assertEqual(p1.provided, [ITestPlugin, plugin.IPlugin])
|
||||
realPlugin = p1.load()
|
||||
# The plugin should match the class present in sys.modules
|
||||
self.assertIdentical(
|
||||
realPlugin,
|
||||
sys.modules['mypackage.%s' % (self.originalPlugin,)].TestPlugin)
|
||||
|
||||
# And it should also match if we import it classicly
|
||||
import mypackage.testplugin as tp
|
||||
self.assertIdentical(realPlugin, tp.TestPlugin)
|
||||
|
||||
test_cache = _withCacheness(test_cache)
|
||||
|
||||
|
||||
def test_plugins(self):
|
||||
"""
|
||||
L{plugin.getPlugins} should return the list of plugins matching the
|
||||
specified interface (here, L{ITestPlugin2}), and these plugins
|
||||
should be instances of classes with a C{test} method, to be sure
|
||||
L{plugin.getPlugins} load classes correctly.
|
||||
"""
|
||||
plugins = list(plugin.getPlugins(ITestPlugin2, self.module))
|
||||
|
||||
self.assertEqual(len(plugins), 2)
|
||||
|
||||
names = ['AnotherTestPlugin', 'ThirdTestPlugin']
|
||||
for p in plugins:
|
||||
names.remove(p.__name__)
|
||||
p.test()
|
||||
|
||||
test_plugins = _withCacheness(test_plugins)
|
||||
|
||||
|
||||
def test_detectNewFiles(self):
|
||||
"""
|
||||
Check that L{plugin.getPlugins} is able to detect plugins added at
|
||||
runtime.
|
||||
"""
|
||||
FilePath(__file__).sibling('plugin_extra1.py'
|
||||
).copyTo(self.package.child('pluginextra.py'))
|
||||
try:
|
||||
# Check that the current situation is clean
|
||||
self.failIfIn('mypackage.pluginextra', sys.modules)
|
||||
self.failIf(hasattr(sys.modules['mypackage'], 'pluginextra'),
|
||||
"mypackage still has pluginextra module")
|
||||
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
# We should find 2 plugins: the one in testplugin, and the one in
|
||||
# pluginextra
|
||||
self.assertEqual(len(plgs), 2)
|
||||
|
||||
names = ['TestPlugin', 'FourthTestPlugin']
|
||||
for p in plgs:
|
||||
names.remove(p.__name__)
|
||||
p.test1()
|
||||
finally:
|
||||
self._unimportPythonModule(
|
||||
sys.modules['mypackage.pluginextra'],
|
||||
True)
|
||||
|
||||
test_detectNewFiles = _withCacheness(test_detectNewFiles)
|
||||
|
||||
|
||||
def test_detectFilesChanged(self):
|
||||
"""
|
||||
Check that if the content of a plugin change, L{plugin.getPlugins} is
|
||||
able to detect the new plugins added.
|
||||
"""
|
||||
FilePath(__file__).sibling('plugin_extra1.py'
|
||||
).copyTo(self.package.child('pluginextra.py'))
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
# Sanity check
|
||||
self.assertEqual(len(plgs), 2)
|
||||
|
||||
FilePath(__file__).sibling('plugin_extra2.py'
|
||||
).copyTo(self.package.child('pluginextra.py'))
|
||||
|
||||
# Fake out Python.
|
||||
self._unimportPythonModule(sys.modules['mypackage.pluginextra'])
|
||||
|
||||
# Make sure additions are noticed
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
self.assertEqual(len(plgs), 3)
|
||||
|
||||
names = ['TestPlugin', 'FourthTestPlugin', 'FifthTestPlugin']
|
||||
for p in plgs:
|
||||
names.remove(p.__name__)
|
||||
p.test1()
|
||||
finally:
|
||||
self._unimportPythonModule(
|
||||
sys.modules['mypackage.pluginextra'],
|
||||
True)
|
||||
|
||||
test_detectFilesChanged = _withCacheness(test_detectFilesChanged)
|
||||
|
||||
|
||||
def test_detectFilesRemoved(self):
|
||||
"""
|
||||
Check that when a dropin file is removed, L{plugin.getPlugins} doesn't
|
||||
return it anymore.
|
||||
"""
|
||||
FilePath(__file__).sibling('plugin_extra1.py'
|
||||
).copyTo(self.package.child('pluginextra.py'))
|
||||
try:
|
||||
# Generate a cache with pluginextra in it.
|
||||
list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
finally:
|
||||
self._unimportPythonModule(
|
||||
sys.modules['mypackage.pluginextra'],
|
||||
True)
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(1, len(plgs))
|
||||
|
||||
test_detectFilesRemoved = _withCacheness(test_detectFilesRemoved)
|
||||
|
||||
|
||||
def test_nonexistentPathEntry(self):
|
||||
"""
|
||||
Test that getCache skips over any entries in a plugin package's
|
||||
C{__path__} which do not exist.
|
||||
"""
|
||||
path = self.mktemp()
|
||||
self.failIf(os.path.exists(path))
|
||||
# Add the test directory to the plugins path
|
||||
self.module.__path__.append(path)
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(len(plgs), 1)
|
||||
finally:
|
||||
self.module.__path__.remove(path)
|
||||
|
||||
test_nonexistentPathEntry = _withCacheness(test_nonexistentPathEntry)
|
||||
|
||||
|
||||
def test_nonDirectoryChildEntry(self):
|
||||
"""
|
||||
Test that getCache skips over any entries in a plugin package's
|
||||
C{__path__} which refer to children of paths which are not directories.
|
||||
"""
|
||||
path = FilePath(self.mktemp())
|
||||
self.failIf(path.exists())
|
||||
path.touch()
|
||||
child = path.child("test_package").path
|
||||
self.module.__path__.append(child)
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(len(plgs), 1)
|
||||
finally:
|
||||
self.module.__path__.remove(child)
|
||||
|
||||
test_nonDirectoryChildEntry = _withCacheness(test_nonDirectoryChildEntry)
|
||||
|
||||
|
||||
def test_deployedMode(self):
|
||||
"""
|
||||
The C{dropin.cache} file may not be writable: the cache should still be
|
||||
attainable, but an error should be logged to show that the cache
|
||||
couldn't be updated.
|
||||
"""
|
||||
# Generate the cache
|
||||
plugin.getCache(self.module)
|
||||
|
||||
cachepath = self.package.child('dropin.cache')
|
||||
|
||||
# Add a new plugin
|
||||
FilePath(__file__).sibling('plugin_extra1.py'
|
||||
).copyTo(self.package.child('pluginextra.py'))
|
||||
|
||||
os.chmod(self.package.path, 0500)
|
||||
# Change the right of dropin.cache too for windows
|
||||
os.chmod(cachepath.path, 0400)
|
||||
self.addCleanup(os.chmod, self.package.path, 0700)
|
||||
self.addCleanup(os.chmod, cachepath.path, 0700)
|
||||
|
||||
# Start observing log events to see the warning
|
||||
events = []
|
||||
addObserver(events.append)
|
||||
self.addCleanup(removeObserver, events.append)
|
||||
|
||||
cache = plugin.getCache(self.module)
|
||||
# The new plugin should be reported
|
||||
self.assertIn('pluginextra', cache)
|
||||
self.assertIn(self.originalPlugin, cache)
|
||||
|
||||
# Make sure something was logged about the cache.
|
||||
expected = "Unable to write to plugin cache %s: error number %d" % (
|
||||
cachepath.path, errno.EPERM)
|
||||
for event in events:
|
||||
if expected in textFromEventDict(event):
|
||||
break
|
||||
else:
|
||||
self.fail(
|
||||
"Did not observe unwriteable cache warning in log "
|
||||
"events: %r" % (events,))
|
||||
|
||||
|
||||
|
||||
# This is something like the Twisted plugins file.
|
||||
pluginInitFile = """
|
||||
from twisted.plugin import pluginPackagePaths
|
||||
__path__.extend(pluginPackagePaths(__name__))
|
||||
__all__ = []
|
||||
"""
|
||||
|
||||
def pluginFileContents(name):
|
||||
return (
|
||||
"from zope.interface import classProvides\n"
|
||||
"from twisted.plugin import IPlugin\n"
|
||||
"from twisted.test.test_plugin import ITestPlugin\n"
|
||||
"\n"
|
||||
"class %s(object):\n"
|
||||
" classProvides(IPlugin, ITestPlugin)\n") % (name,)
|
||||
|
||||
|
||||
def _createPluginDummy(entrypath, pluginContent, real, pluginModule):
|
||||
"""
|
||||
Create a plugindummy package.
|
||||
"""
|
||||
entrypath.createDirectory()
|
||||
pkg = entrypath.child('plugindummy')
|
||||
pkg.createDirectory()
|
||||
if real:
|
||||
pkg.child('__init__.py').setContent('')
|
||||
plugs = pkg.child('plugins')
|
||||
plugs.createDirectory()
|
||||
if real:
|
||||
plugs.child('__init__.py').setContent(pluginInitFile)
|
||||
plugs.child(pluginModule + '.py').setContent(pluginContent)
|
||||
return plugs
|
||||
|
||||
|
||||
|
||||
class DeveloperSetupTests(unittest.TestCase):
|
||||
"""
|
||||
These tests verify things about the plugin system without actually
|
||||
interacting with the deployed 'twisted.plugins' package, instead creating a
|
||||
temporary package.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a complex environment with multiple entries on sys.path, akin to
|
||||
a developer's environment who has a development (trunk) checkout of
|
||||
Twisted, a system installed version of Twisted (for their operating
|
||||
system's tools) and a project which provides Twisted plugins.
|
||||
"""
|
||||
self.savedPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
self.fakeRoot = FilePath(self.mktemp())
|
||||
self.fakeRoot.createDirectory()
|
||||
self.systemPath = self.fakeRoot.child('system_path')
|
||||
self.devPath = self.fakeRoot.child('development_path')
|
||||
self.appPath = self.fakeRoot.child('application_path')
|
||||
self.systemPackage = _createPluginDummy(
|
||||
self.systemPath, pluginFileContents('system'),
|
||||
True, 'plugindummy_builtin')
|
||||
self.devPackage = _createPluginDummy(
|
||||
self.devPath, pluginFileContents('dev'),
|
||||
True, 'plugindummy_builtin')
|
||||
self.appPackage = _createPluginDummy(
|
||||
self.appPath, pluginFileContents('app'),
|
||||
False, 'plugindummy_app')
|
||||
|
||||
# Now we're going to do the system installation.
|
||||
sys.path.extend([x.path for x in [self.systemPath,
|
||||
self.appPath]])
|
||||
# Run all the way through the plugins list to cause the
|
||||
# L{plugin.getPlugins} generator to write cache files for the system
|
||||
# installation.
|
||||
self.getAllPlugins()
|
||||
self.sysplug = self.systemPath.child('plugindummy').child('plugins')
|
||||
self.syscache = self.sysplug.child('dropin.cache')
|
||||
# Make sure there's a nice big difference in modification times so that
|
||||
# we won't re-build the system cache.
|
||||
now = time.time()
|
||||
os.utime(
|
||||
self.sysplug.child('plugindummy_builtin.py').path,
|
||||
(now - 5000,) * 2)
|
||||
os.utime(self.syscache.path, (now - 2000,) * 2)
|
||||
# For extra realism, let's make sure that the system path is no longer
|
||||
# writable.
|
||||
self.lockSystem()
|
||||
self.resetEnvironment()
|
||||
|
||||
|
||||
def lockSystem(self):
|
||||
"""
|
||||
Lock the system directories, as if they were unwritable by this user.
|
||||
"""
|
||||
os.chmod(self.sysplug.path, 0555)
|
||||
os.chmod(self.syscache.path, 0555)
|
||||
|
||||
|
||||
def unlockSystem(self):
|
||||
"""
|
||||
Unlock the system directories, as if they were writable by this user.
|
||||
"""
|
||||
os.chmod(self.sysplug.path, 0777)
|
||||
os.chmod(self.syscache.path, 0777)
|
||||
|
||||
|
||||
def getAllPlugins(self):
|
||||
"""
|
||||
Get all the plugins loadable from our dummy package, and return their
|
||||
short names.
|
||||
"""
|
||||
# Import the module we just added to our path. (Local scope because
|
||||
# this package doesn't exist outside of this test.)
|
||||
import plugindummy.plugins
|
||||
x = list(plugin.getPlugins(ITestPlugin, plugindummy.plugins))
|
||||
return [plug.__name__ for plug in x]
|
||||
|
||||
|
||||
def resetEnvironment(self):
|
||||
"""
|
||||
Change the environment to what it should be just as the test is
|
||||
starting.
|
||||
"""
|
||||
self.unsetEnvironment()
|
||||
sys.path.extend([x.path for x in [self.devPath,
|
||||
self.systemPath,
|
||||
self.appPath]])
|
||||
|
||||
def unsetEnvironment(self):
|
||||
"""
|
||||
Change the Python environment back to what it was before the test was
|
||||
started.
|
||||
"""
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
sys.path[:] = self.savedPath
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Reset the Python environment to what it was before this test ran, and
|
||||
restore permissions on files which were marked read-only so that the
|
||||
directory may be cleanly cleaned up.
|
||||
"""
|
||||
self.unsetEnvironment()
|
||||
# Normally we wouldn't "clean up" the filesystem like this (leaving
|
||||
# things for post-test inspection), but if we left the permissions the
|
||||
# way they were, we'd be leaving files around that the buildbots
|
||||
# couldn't delete, and that would be bad.
|
||||
self.unlockSystem()
|
||||
|
||||
|
||||
def test_developmentPluginAvailability(self):
|
||||
"""
|
||||
Plugins added in the development path should be loadable, even when
|
||||
the (now non-importable) system path contains its own idea of the
|
||||
list of plugins for a package. Inversely, plugins added in the
|
||||
system path should not be available.
|
||||
"""
|
||||
# Run 3 times: uncached, cached, and then cached again to make sure we
|
||||
# didn't overwrite / corrupt the cache on the cached try.
|
||||
for x in range(3):
|
||||
names = self.getAllPlugins()
|
||||
names.sort()
|
||||
self.assertEqual(names, ['app', 'dev'])
|
||||
|
||||
|
||||
def test_freshPyReplacesStalePyc(self):
|
||||
"""
|
||||
Verify that if a stale .pyc file on the PYTHONPATH is replaced by a
|
||||
fresh .py file, the plugins in the new .py are picked up rather than
|
||||
the stale .pyc, even if the .pyc is still around.
|
||||
"""
|
||||
mypath = self.appPackage.child("stale.py")
|
||||
mypath.setContent(pluginFileContents('one'))
|
||||
# Make it super stale
|
||||
x = time.time() - 1000
|
||||
os.utime(mypath.path, (x, x))
|
||||
pyc = mypath.sibling('stale.pyc')
|
||||
# compile it
|
||||
compileall.compile_dir(self.appPackage.path, quiet=1)
|
||||
os.utime(pyc.path, (x, x))
|
||||
# Eliminate the other option.
|
||||
mypath.remove()
|
||||
# Make sure it's the .pyc path getting cached.
|
||||
self.resetEnvironment()
|
||||
# Sanity check.
|
||||
self.assertIn('one', self.getAllPlugins())
|
||||
self.failIfIn('two', self.getAllPlugins())
|
||||
self.resetEnvironment()
|
||||
mypath.setContent(pluginFileContents('two'))
|
||||
self.failIfIn('one', self.getAllPlugins())
|
||||
self.assertIn('two', self.getAllPlugins())
|
||||
|
||||
|
||||
def test_newPluginsOnReadOnlyPath(self):
|
||||
"""
|
||||
Verify that a failure to write the dropin.cache file on a read-only
|
||||
path will not affect the list of plugins returned.
|
||||
|
||||
Note: this test should pass on both Linux and Windows, but may not
|
||||
provide useful coverage on Windows due to the different meaning of
|
||||
"read-only directory".
|
||||
"""
|
||||
self.unlockSystem()
|
||||
self.sysplug.child('newstuff.py').setContent(pluginFileContents('one'))
|
||||
self.lockSystem()
|
||||
|
||||
# Take the developer path out, so that the system plugins are actually
|
||||
# examined.
|
||||
sys.path.remove(self.devPath.path)
|
||||
|
||||
# Start observing log events to see the warning
|
||||
events = []
|
||||
addObserver(events.append)
|
||||
self.addCleanup(removeObserver, events.append)
|
||||
|
||||
self.assertIn('one', self.getAllPlugins())
|
||||
|
||||
# Make sure something was logged about the cache.
|
||||
expected = "Unable to write to plugin cache %s: error number %d" % (
|
||||
self.syscache.path, errno.EPERM)
|
||||
for event in events:
|
||||
if expected in textFromEventDict(event):
|
||||
break
|
||||
else:
|
||||
self.fail(
|
||||
"Did not observe unwriteable cache warning in log "
|
||||
"events: %r" % (events,))
|
||||
|
||||
|
||||
|
||||
class AdjacentPackageTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the behavior of the plugin system when there are multiple
|
||||
installed copies of the package containing the plugins being loaded.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Save the elements of C{sys.path} and the items of C{sys.modules}.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore C{sys.path} and C{sys.modules} to their original values.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
|
||||
|
||||
def createDummyPackage(self, root, name, pluginName):
|
||||
"""
|
||||
Create a directory containing a Python package named I{dummy} with a
|
||||
I{plugins} subpackage.
|
||||
|
||||
@type root: L{FilePath}
|
||||
@param root: The directory in which to create the hierarchy.
|
||||
|
||||
@type name: C{str}
|
||||
@param name: The name of the directory to create which will contain
|
||||
the package.
|
||||
|
||||
@type pluginName: C{str}
|
||||
@param pluginName: The name of a module to create in the
|
||||
I{dummy.plugins} package.
|
||||
|
||||
@rtype: L{FilePath}
|
||||
@return: The directory which was created to contain the I{dummy}
|
||||
package.
|
||||
"""
|
||||
directory = root.child(name)
|
||||
package = directory.child('dummy')
|
||||
package.makedirs()
|
||||
package.child('__init__.py').setContent('')
|
||||
plugins = package.child('plugins')
|
||||
plugins.makedirs()
|
||||
plugins.child('__init__.py').setContent(pluginInitFile)
|
||||
pluginModule = plugins.child(pluginName + '.py')
|
||||
pluginModule.setContent(pluginFileContents(name))
|
||||
return directory
|
||||
|
||||
|
||||
def test_hiddenPackageSamePluginModuleNameObscured(self):
|
||||
"""
|
||||
Only plugins from the first package in sys.path should be returned by
|
||||
getPlugins in the case where there are two Python packages by the same
|
||||
name installed, each with a plugin module by a single name.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
root.makedirs()
|
||||
|
||||
firstDirectory = self.createDummyPackage(root, 'first', 'someplugin')
|
||||
secondDirectory = self.createDummyPackage(root, 'second', 'someplugin')
|
||||
|
||||
sys.path.append(firstDirectory.path)
|
||||
sys.path.append(secondDirectory.path)
|
||||
|
||||
import dummy.plugins
|
||||
|
||||
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
|
||||
self.assertEqual(['first'], [p.__name__ for p in plugins])
|
||||
|
||||
|
||||
def test_hiddenPackageDifferentPluginModuleNameObscured(self):
|
||||
"""
|
||||
Plugins from the first package in sys.path should be returned by
|
||||
getPlugins in the case where there are two Python packages by the same
|
||||
name installed, each with a plugin module by a different name.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
root.makedirs()
|
||||
|
||||
firstDirectory = self.createDummyPackage(root, 'first', 'thisplugin')
|
||||
secondDirectory = self.createDummyPackage(root, 'second', 'thatplugin')
|
||||
|
||||
sys.path.append(firstDirectory.path)
|
||||
sys.path.append(secondDirectory.path)
|
||||
|
||||
import dummy.plugins
|
||||
|
||||
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
|
||||
self.assertEqual(['first'], [p.__name__ for p in plugins])
|
||||
|
||||
|
||||
|
||||
class PackagePathTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{plugin.pluginPackagePaths} which constructs search paths for
|
||||
plugin packages.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Save the elements of C{sys.path}.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore C{sys.path} to its original value.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
|
||||
|
||||
def test_pluginDirectories(self):
|
||||
"""
|
||||
L{plugin.pluginPackagePaths} should return a list containing each
|
||||
directory in C{sys.path} with a suffix based on the supplied package
|
||||
name.
|
||||
"""
|
||||
foo = FilePath('foo')
|
||||
bar = FilePath('bar')
|
||||
sys.path = [foo.path, bar.path]
|
||||
self.assertEqual(
|
||||
plugin.pluginPackagePaths('dummy.plugins'),
|
||||
[foo.child('dummy').child('plugins').path,
|
||||
bar.child('dummy').child('plugins').path])
|
||||
|
||||
|
||||
def test_pluginPackagesExcluded(self):
|
||||
"""
|
||||
L{plugin.pluginPackagePaths} should exclude directories which are
|
||||
Python packages. The only allowed plugin package (the only one
|
||||
associated with a I{dummy} package which Python will allow to be
|
||||
imported) will already be known to the caller of
|
||||
L{plugin.pluginPackagePaths} and will most commonly already be in
|
||||
the C{__path__} they are about to mutate.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
foo = root.child('foo').child('dummy').child('plugins')
|
||||
foo.makedirs()
|
||||
foo.child('__init__.py').setContent('')
|
||||
sys.path = [root.child('foo').path, root.child('bar').path]
|
||||
self.assertEqual(
|
||||
plugin.pluginPackagePaths('dummy.plugins'),
|
||||
[root.child('bar').child('dummy').child('plugins').path])
|
||||
|
|
@ -0,0 +1,872 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test code for policies.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from zope.interface import Interface, implementer, implementedBy
|
||||
|
||||
from twisted.python.compat import NativeStringIO, _PY3
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.test.proto_helpers import StringTransportWithDisconnection
|
||||
|
||||
from twisted.internet import protocol, reactor, address, defer, task
|
||||
from twisted.protocols import policies
|
||||
|
||||
|
||||
|
||||
class SimpleProtocol(protocol.Protocol):
|
||||
|
||||
connected = disconnected = 0
|
||||
buffer = b""
|
||||
|
||||
def __init__(self):
|
||||
self.dConnected = defer.Deferred()
|
||||
self.dDisconnected = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.connected = 1
|
||||
self.dConnected.callback('')
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.disconnected = 1
|
||||
self.dDisconnected.callback('')
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buffer += data
|
||||
|
||||
|
||||
|
||||
class SillyFactory(protocol.ClientFactory):
|
||||
|
||||
def __init__(self, p):
|
||||
self.p = p
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.p
|
||||
|
||||
|
||||
class EchoProtocol(protocol.Protocol):
|
||||
paused = False
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
|
||||
class Server(protocol.ServerFactory):
|
||||
"""
|
||||
A simple server factory using L{EchoProtocol}.
|
||||
"""
|
||||
protocol = EchoProtocol
|
||||
|
||||
|
||||
|
||||
class TestableThrottlingFactory(policies.ThrottlingFactory):
|
||||
"""
|
||||
L{policies.ThrottlingFactory} using a L{task.Clock} for tests.
|
||||
"""
|
||||
|
||||
def __init__(self, clock, *args, **kwargs):
|
||||
"""
|
||||
@param clock: object providing a callLater method that can be used
|
||||
for tests.
|
||||
@type clock: C{task.Clock} or alike.
|
||||
"""
|
||||
policies.ThrottlingFactory.__init__(self, *args, **kwargs)
|
||||
self.clock = clock
|
||||
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Forward to the testable clock.
|
||||
"""
|
||||
return self.clock.callLater(period, func)
|
||||
|
||||
|
||||
|
||||
class TestableTimeoutFactory(policies.TimeoutFactory):
|
||||
"""
|
||||
L{policies.TimeoutFactory} using a L{task.Clock} for tests.
|
||||
"""
|
||||
|
||||
def __init__(self, clock, *args, **kwargs):
|
||||
"""
|
||||
@param clock: object providing a callLater method that can be used
|
||||
for tests.
|
||||
@type clock: C{task.Clock} or alike.
|
||||
"""
|
||||
policies.TimeoutFactory.__init__(self, *args, **kwargs)
|
||||
self.clock = clock
|
||||
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Forward to the testable clock.
|
||||
"""
|
||||
return self.clock.callLater(period, func)
|
||||
|
||||
|
||||
|
||||
class WrapperTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{WrappingFactory} and L{ProtocolWrapper}.
|
||||
"""
|
||||
def test_protocolFactoryAttribute(self):
|
||||
"""
|
||||
Make sure protocol.factory is the wrapped factory, not the wrapping
|
||||
factory.
|
||||
"""
|
||||
f = Server()
|
||||
wf = policies.WrappingFactory(f)
|
||||
p = wf.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 35))
|
||||
self.assertIdentical(p.wrappedProtocol.factory, f)
|
||||
|
||||
|
||||
def test_transportInterfaces(self):
|
||||
"""
|
||||
The transport wrapper passed to the wrapped protocol's
|
||||
C{makeConnection} provides the same interfaces as are provided by the
|
||||
original transport.
|
||||
"""
|
||||
class IStubTransport(Interface):
|
||||
pass
|
||||
|
||||
@implementer(IStubTransport)
|
||||
class StubTransport:
|
||||
pass
|
||||
|
||||
# Looking up what ProtocolWrapper implements also mutates the class.
|
||||
# It adds __implemented__ and __providedBy__ attributes to it. These
|
||||
# prevent __getattr__ from causing the IStubTransport.providedBy call
|
||||
# below from returning True. If, by accident, nothing else causes
|
||||
# these attributes to be added to ProtocolWrapper, the test will pass,
|
||||
# but the interface will only be provided until something does trigger
|
||||
# their addition. So we just trigger it right now to be sure.
|
||||
implementedBy(policies.ProtocolWrapper)
|
||||
|
||||
proto = protocol.Protocol()
|
||||
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(None), proto)
|
||||
|
||||
wrapper.makeConnection(StubTransport())
|
||||
self.assertTrue(IStubTransport.providedBy(proto.transport))
|
||||
|
||||
|
||||
def test_factoryLogPrefix(self):
|
||||
"""
|
||||
L{WrappingFactory.logPrefix} is customized to mention both the original
|
||||
factory and the wrapping factory.
|
||||
"""
|
||||
server = Server()
|
||||
factory = policies.WrappingFactory(server)
|
||||
self.assertEqual("Server (WrappingFactory)", factory.logPrefix())
|
||||
|
||||
|
||||
def test_factoryLogPrefixFallback(self):
|
||||
"""
|
||||
If the wrapped factory doesn't have a L{logPrefix} method,
|
||||
L{WrappingFactory.logPrefix} falls back to the factory class name.
|
||||
"""
|
||||
class NoFactory(object):
|
||||
pass
|
||||
|
||||
server = NoFactory()
|
||||
factory = policies.WrappingFactory(server)
|
||||
self.assertEqual("NoFactory (WrappingFactory)", factory.logPrefix())
|
||||
|
||||
|
||||
def test_protocolLogPrefix(self):
|
||||
"""
|
||||
L{ProtocolWrapper.logPrefix} is customized to mention both the original
|
||||
protocol and the wrapper.
|
||||
"""
|
||||
server = Server()
|
||||
factory = policies.WrappingFactory(server)
|
||||
protocol = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', '127.0.0.1', 35))
|
||||
self.assertEqual("EchoProtocol (ProtocolWrapper)",
|
||||
protocol.logPrefix())
|
||||
|
||||
|
||||
def test_protocolLogPrefixFallback(self):
|
||||
"""
|
||||
If the wrapped protocol doesn't have a L{logPrefix} method,
|
||||
L{ProtocolWrapper.logPrefix} falls back to the protocol class name.
|
||||
"""
|
||||
class NoProtocol(object):
|
||||
pass
|
||||
|
||||
server = Server()
|
||||
server.protocol = NoProtocol
|
||||
factory = policies.WrappingFactory(server)
|
||||
protocol = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', '127.0.0.1', 35))
|
||||
self.assertEqual("NoProtocol (ProtocolWrapper)",
|
||||
protocol.logPrefix())
|
||||
|
||||
|
||||
def _getWrapper(self):
|
||||
"""
|
||||
Return L{policies.ProtocolWrapper} that has been connected to a
|
||||
L{StringTransport}.
|
||||
"""
|
||||
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(Server()),
|
||||
protocol.Protocol())
|
||||
transport = StringTransport()
|
||||
wrapper.makeConnection(transport)
|
||||
return wrapper
|
||||
|
||||
|
||||
def test_getHost(self):
|
||||
"""
|
||||
L{policies.ProtocolWrapper.getHost} calls C{getHost} on the underlying
|
||||
transport.
|
||||
"""
|
||||
wrapper = self._getWrapper()
|
||||
self.assertEqual(wrapper.getHost(), wrapper.transport.getHost())
|
||||
|
||||
|
||||
def test_getPeer(self):
|
||||
"""
|
||||
L{policies.ProtocolWrapper.getPeer} calls C{getPeer} on the underlying
|
||||
transport.
|
||||
"""
|
||||
wrapper = self._getWrapper()
|
||||
self.assertEqual(wrapper.getPeer(), wrapper.transport.getPeer())
|
||||
|
||||
|
||||
def test_registerProducer(self):
|
||||
"""
|
||||
L{policies.ProtocolWrapper.registerProducer} calls C{registerProducer}
|
||||
on the underlying transport.
|
||||
"""
|
||||
wrapper = self._getWrapper()
|
||||
producer = object()
|
||||
wrapper.registerProducer(producer, True)
|
||||
self.assertIdentical(wrapper.transport.producer, producer)
|
||||
self.assertTrue(wrapper.transport.streaming)
|
||||
|
||||
|
||||
def test_unregisterProducer(self):
|
||||
"""
|
||||
L{policies.ProtocolWrapper.unregisterProducer} calls
|
||||
C{unregisterProducer} on the underlying transport.
|
||||
"""
|
||||
wrapper = self._getWrapper()
|
||||
producer = object()
|
||||
wrapper.registerProducer(producer, True)
|
||||
wrapper.unregisterProducer()
|
||||
self.assertIdentical(wrapper.transport.producer, None)
|
||||
self.assertIdentical(wrapper.transport.streaming, None)
|
||||
|
||||
|
||||
def test_stopConsuming(self):
|
||||
"""
|
||||
L{policies.ProtocolWrapper.stopConsuming} calls C{stopConsuming} on
|
||||
the underlying transport.
|
||||
"""
|
||||
wrapper = self._getWrapper()
|
||||
result = []
|
||||
wrapper.transport.stopConsuming = lambda: result.append(True)
|
||||
wrapper.stopConsuming()
|
||||
self.assertEqual(result, [True])
|
||||
|
||||
|
||||
def test_startedConnecting(self):
|
||||
"""
|
||||
L{policies.WrappingFactory.startedConnecting} calls
|
||||
C{startedConnecting} on the underlying factory.
|
||||
"""
|
||||
result = []
|
||||
class Factory(object):
|
||||
def startedConnecting(self, connector):
|
||||
result.append(connector)
|
||||
|
||||
wrapper = policies.WrappingFactory(Factory())
|
||||
connector = object()
|
||||
wrapper.startedConnecting(connector)
|
||||
self.assertEqual(result, [connector])
|
||||
|
||||
|
||||
def test_clientConnectionLost(self):
|
||||
"""
|
||||
L{policies.WrappingFactory.clientConnectionLost} calls
|
||||
C{clientConnectionLost} on the underlying factory.
|
||||
"""
|
||||
result = []
|
||||
class Factory(object):
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
result.append((connector, reason))
|
||||
|
||||
wrapper = policies.WrappingFactory(Factory())
|
||||
connector = object()
|
||||
reason = object()
|
||||
wrapper.clientConnectionLost(connector, reason)
|
||||
self.assertEqual(result, [(connector, reason)])
|
||||
|
||||
|
||||
def test_clientConnectionFailed(self):
|
||||
"""
|
||||
L{policies.WrappingFactory.clientConnectionFailed} calls
|
||||
C{clientConnectionFailed} on the underlying factory.
|
||||
"""
|
||||
result = []
|
||||
class Factory(object):
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
result.append((connector, reason))
|
||||
|
||||
wrapper = policies.WrappingFactory(Factory())
|
||||
connector = object()
|
||||
reason = object()
|
||||
wrapper.clientConnectionFailed(connector, reason)
|
||||
self.assertEqual(result, [(connector, reason)])
|
||||
|
||||
|
||||
|
||||
class WrappingFactory(policies.WrappingFactory):
|
||||
protocol = lambda s, f, p: p
|
||||
|
||||
def startFactory(self):
|
||||
policies.WrappingFactory.startFactory(self)
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
|
||||
class ThrottlingTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{policies.ThrottlingFactory}.
|
||||
"""
|
||||
|
||||
def test_limit(self):
|
||||
"""
|
||||
Full test using a custom server limiting number of connections.
|
||||
"""
|
||||
server = Server()
|
||||
c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
|
||||
tServer = policies.ThrottlingFactory(server, 2)
|
||||
wrapTServer = WrappingFactory(tServer)
|
||||
wrapTServer.deferred = defer.Deferred()
|
||||
|
||||
# Start listening
|
||||
p = reactor.listenTCP(0, wrapTServer, interface="127.0.0.1")
|
||||
n = p.getHost().port
|
||||
|
||||
def _connect123(results):
|
||||
reactor.connectTCP("127.0.0.1", n, SillyFactory(c1))
|
||||
c1.dConnected.addCallback(
|
||||
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c2)))
|
||||
c2.dConnected.addCallback(
|
||||
lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c3)))
|
||||
return c3.dDisconnected
|
||||
|
||||
def _check123(results):
|
||||
self.assertEqual([c.connected for c in (c1, c2, c3)], [1, 1, 1])
|
||||
self.assertEqual([c.disconnected for c in (c1, c2, c3)], [0, 0, 1])
|
||||
self.assertEqual(len(tServer.protocols.keys()), 2)
|
||||
return results
|
||||
|
||||
def _lose1(results):
|
||||
# disconnect one protocol and now another should be able to connect
|
||||
c1.transport.loseConnection()
|
||||
return c1.dDisconnected
|
||||
|
||||
def _connect4(results):
|
||||
reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
|
||||
return c4.dConnected
|
||||
|
||||
def _check4(results):
|
||||
self.assertEqual(c4.connected, 1)
|
||||
self.assertEqual(c4.disconnected, 0)
|
||||
return results
|
||||
|
||||
def _cleanup(results):
|
||||
for c in c2, c4:
|
||||
c.transport.loseConnection()
|
||||
return defer.DeferredList([
|
||||
defer.maybeDeferred(p.stopListening),
|
||||
c2.dDisconnected,
|
||||
c4.dDisconnected])
|
||||
|
||||
wrapTServer.deferred.addCallback(_connect123)
|
||||
wrapTServer.deferred.addCallback(_check123)
|
||||
wrapTServer.deferred.addCallback(_lose1)
|
||||
wrapTServer.deferred.addCallback(_connect4)
|
||||
wrapTServer.deferred.addCallback(_check4)
|
||||
wrapTServer.deferred.addCallback(_cleanup)
|
||||
return wrapTServer.deferred
|
||||
|
||||
|
||||
def test_writeSequence(self):
|
||||
"""
|
||||
L{ThrottlingProtocol.writeSequence} is called on the underlying factory.
|
||||
"""
|
||||
server = Server()
|
||||
tServer = TestableThrottlingFactory(task.Clock(), server)
|
||||
protocol = tServer.buildProtocol(
|
||||
address.IPv4Address('TCP', '127.0.0.1', 0))
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = protocol
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
protocol.writeSequence([b'bytes'] * 4)
|
||||
self.assertEqual(transport.value(), b"bytesbytesbytesbytes")
|
||||
self.assertEqual(tServer.writtenThisSecond, 20)
|
||||
|
||||
|
||||
def test_writeLimit(self):
|
||||
"""
|
||||
Check the writeLimit parameter: write data, and check for the pause
|
||||
status.
|
||||
"""
|
||||
server = Server()
|
||||
tServer = TestableThrottlingFactory(task.Clock(), server, writeLimit=10)
|
||||
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
|
||||
tr = StringTransportWithDisconnection()
|
||||
tr.protocol = port
|
||||
port.makeConnection(tr)
|
||||
port.producer = port.wrappedProtocol
|
||||
|
||||
port.dataReceived(b"0123456789")
|
||||
port.dataReceived(b"abcdefghij")
|
||||
self.assertEqual(tr.value(), b"0123456789abcdefghij")
|
||||
self.assertEqual(tServer.writtenThisSecond, 20)
|
||||
self.assertFalse(port.wrappedProtocol.paused)
|
||||
|
||||
# at this point server should've written 20 bytes, 10 bytes
|
||||
# above the limit so writing should be paused around 1 second
|
||||
# from 'now', and resumed a second after that
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.writtenThisSecond, 0)
|
||||
self.assertTrue(port.wrappedProtocol.paused)
|
||||
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.writtenThisSecond, 0)
|
||||
self.assertFalse(port.wrappedProtocol.paused)
|
||||
|
||||
|
||||
def test_readLimit(self):
|
||||
"""
|
||||
Check the readLimit parameter: read data and check for the pause
|
||||
status.
|
||||
"""
|
||||
server = Server()
|
||||
tServer = TestableThrottlingFactory(task.Clock(), server, readLimit=10)
|
||||
port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
|
||||
tr = StringTransportWithDisconnection()
|
||||
tr.protocol = port
|
||||
port.makeConnection(tr)
|
||||
|
||||
port.dataReceived(b"0123456789")
|
||||
port.dataReceived(b"abcdefghij")
|
||||
self.assertEqual(tr.value(), b"0123456789abcdefghij")
|
||||
self.assertEqual(tServer.readThisSecond, 20)
|
||||
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.readThisSecond, 0)
|
||||
self.assertEqual(tr.producerState, 'paused')
|
||||
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.readThisSecond, 0)
|
||||
self.assertEqual(tr.producerState, 'producing')
|
||||
|
||||
tr.clear()
|
||||
port.dataReceived(b"0123456789")
|
||||
port.dataReceived(b"abcdefghij")
|
||||
self.assertEqual(tr.value(), b"0123456789abcdefghij")
|
||||
self.assertEqual(tServer.readThisSecond, 20)
|
||||
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.readThisSecond, 0)
|
||||
self.assertEqual(tr.producerState, 'paused')
|
||||
|
||||
tServer.clock.advance(1.05)
|
||||
self.assertEqual(tServer.readThisSecond, 0)
|
||||
self.assertEqual(tr.producerState, 'producing')
|
||||
|
||||
|
||||
|
||||
class TimeoutTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{policies.TimeoutFactory}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a testable, deterministic clock, and a set of
|
||||
server factory/protocol/transport.
|
||||
"""
|
||||
self.clock = task.Clock()
|
||||
wrappedFactory = protocol.ServerFactory()
|
||||
wrappedFactory.protocol = SimpleProtocol
|
||||
self.factory = TestableTimeoutFactory(self.clock, wrappedFactory, 3)
|
||||
self.proto = self.factory.buildProtocol(
|
||||
address.IPv4Address('TCP', '127.0.0.1', 12345))
|
||||
self.transport = StringTransportWithDisconnection()
|
||||
self.transport.protocol = self.proto
|
||||
self.proto.makeConnection(self.transport)
|
||||
|
||||
|
||||
def test_timeout(self):
|
||||
"""
|
||||
Make sure that when a TimeoutFactory accepts a connection, it will
|
||||
time out that connection if no data is read or written within the
|
||||
timeout period.
|
||||
"""
|
||||
# Let almost 3 time units pass
|
||||
self.clock.pump([0.0, 0.5, 1.0, 1.0, 0.4])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Now let the timer elapse
|
||||
self.clock.pump([0.0, 0.2])
|
||||
self.failUnless(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
|
||||
def test_sendAvoidsTimeout(self):
|
||||
"""
|
||||
Make sure that writing data to a transport from a protocol
|
||||
constructed by a TimeoutFactory resets the timeout countdown.
|
||||
"""
|
||||
# Let half the countdown period elapse
|
||||
self.clock.pump([0.0, 0.5, 1.0])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Send some data (self.proto is the /real/ proto's transport, so this
|
||||
# is the write that gets called)
|
||||
self.proto.write(b'bytes bytes bytes')
|
||||
|
||||
# More time passes, putting us past the original timeout
|
||||
self.clock.pump([0.0, 1.0, 1.0])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Make sure writeSequence delays timeout as well
|
||||
self.proto.writeSequence([b'bytes'] * 3)
|
||||
|
||||
# Tick tock
|
||||
self.clock.pump([0.0, 1.0, 1.0])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Don't write anything more, just let the timeout expire
|
||||
self.clock.pump([0.0, 2.0])
|
||||
self.failUnless(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
|
||||
def test_receiveAvoidsTimeout(self):
|
||||
"""
|
||||
Make sure that receiving data also resets the timeout countdown.
|
||||
"""
|
||||
# Let half the countdown period elapse
|
||||
self.clock.pump([0.0, 1.0, 0.5])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Some bytes arrive, they should reset the counter
|
||||
self.proto.dataReceived(b'bytes bytes bytes')
|
||||
|
||||
# We pass the original timeout
|
||||
self.clock.pump([0.0, 1.0, 1.0])
|
||||
self.failIf(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
# Nothing more arrives though, the new timeout deadline is passed,
|
||||
# the connection should be dropped.
|
||||
self.clock.pump([0.0, 1.0, 1.0])
|
||||
self.failUnless(self.proto.wrappedProtocol.disconnected)
|
||||
|
||||
|
||||
|
||||
class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
|
||||
"""
|
||||
A testable protocol with timeout facility.
|
||||
|
||||
@ivar timedOut: set to C{True} if a timeout has been detected.
|
||||
@type timedOut: C{bool}
|
||||
"""
|
||||
timeOut = 3
|
||||
timedOut = False
|
||||
|
||||
def __init__(self, clock):
|
||||
"""
|
||||
Initialize the protocol with a C{task.Clock} object.
|
||||
"""
|
||||
self.clock = clock
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
"""
|
||||
Upon connection, set the timeout.
|
||||
"""
|
||||
self.setTimeout(self.timeOut)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Reset the timeout on data.
|
||||
"""
|
||||
self.resetTimeout()
|
||||
protocol.Protocol.dataReceived(self, data)
|
||||
|
||||
|
||||
def connectionLost(self, reason=None):
|
||||
"""
|
||||
On connection lost, cancel all timeout operations.
|
||||
"""
|
||||
self.setTimeout(None)
|
||||
|
||||
|
||||
def timeoutConnection(self):
|
||||
"""
|
||||
Flags the timedOut variable to indicate the timeout of the connection.
|
||||
"""
|
||||
self.timedOut = True
|
||||
|
||||
|
||||
def callLater(self, timeout, func, *args, **kwargs):
|
||||
"""
|
||||
Override callLater to use the deterministic clock.
|
||||
"""
|
||||
return self.clock.callLater(timeout, func, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TestTimeout(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{policies.TimeoutMixin}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a testable, deterministic clock and a C{TimeoutTester} instance.
|
||||
"""
|
||||
self.clock = task.Clock()
|
||||
self.proto = TimeoutTester(self.clock)
|
||||
|
||||
|
||||
def test_overriddenCallLater(self):
|
||||
"""
|
||||
Test that the callLater of the clock is used instead of
|
||||
C{reactor.callLater}.
|
||||
"""
|
||||
self.proto.setTimeout(10)
|
||||
self.assertEqual(len(self.clock.calls), 1)
|
||||
|
||||
|
||||
def test_timeout(self):
|
||||
"""
|
||||
Check that the protocol does timeout at the time specified by its
|
||||
C{timeOut} attribute.
|
||||
"""
|
||||
self.proto.makeConnection(StringTransport())
|
||||
|
||||
# timeOut value is 3
|
||||
self.clock.pump([0, 0.5, 1.0, 1.0])
|
||||
self.failIf(self.proto.timedOut)
|
||||
self.clock.pump([0, 1.0])
|
||||
self.failUnless(self.proto.timedOut)
|
||||
|
||||
|
||||
def test_noTimeout(self):
|
||||
"""
|
||||
Check that receiving data is delaying the timeout of the connection.
|
||||
"""
|
||||
self.proto.makeConnection(StringTransport())
|
||||
|
||||
self.clock.pump([0, 0.5, 1.0, 1.0])
|
||||
self.failIf(self.proto.timedOut)
|
||||
self.proto.dataReceived(b'hello there')
|
||||
self.clock.pump([0, 1.0, 1.0, 0.5])
|
||||
self.failIf(self.proto.timedOut)
|
||||
self.clock.pump([0, 1.0])
|
||||
self.failUnless(self.proto.timedOut)
|
||||
|
||||
|
||||
def test_resetTimeout(self):
|
||||
"""
|
||||
Check that setting a new value for timeout cancel the previous value
|
||||
and install a new timeout.
|
||||
"""
|
||||
self.proto.timeOut = None
|
||||
self.proto.makeConnection(StringTransport())
|
||||
|
||||
self.proto.setTimeout(1)
|
||||
self.assertEqual(self.proto.timeOut, 1)
|
||||
|
||||
self.clock.pump([0, 0.9])
|
||||
self.failIf(self.proto.timedOut)
|
||||
self.clock.pump([0, 0.2])
|
||||
self.failUnless(self.proto.timedOut)
|
||||
|
||||
|
||||
def test_cancelTimeout(self):
|
||||
"""
|
||||
Setting the timeout to C{None} cancel any timeout operations.
|
||||
"""
|
||||
self.proto.timeOut = 5
|
||||
self.proto.makeConnection(StringTransport())
|
||||
|
||||
self.proto.setTimeout(None)
|
||||
self.assertEqual(self.proto.timeOut, None)
|
||||
|
||||
self.clock.pump([0, 5, 5, 5])
|
||||
self.failIf(self.proto.timedOut)
|
||||
|
||||
|
||||
def test_return(self):
|
||||
"""
|
||||
setTimeout should return the value of the previous timeout.
|
||||
"""
|
||||
self.proto.timeOut = 5
|
||||
|
||||
self.assertEqual(self.proto.setTimeout(10), 5)
|
||||
self.assertEqual(self.proto.setTimeout(None), 10)
|
||||
self.assertEqual(self.proto.setTimeout(1), None)
|
||||
self.assertEqual(self.proto.timeOut, 1)
|
||||
|
||||
# Clean up the DelayedCall
|
||||
self.proto.setTimeout(None)
|
||||
|
||||
|
||||
|
||||
class LimitTotalConnectionsFactoryTestCase(unittest.TestCase):
|
||||
"""Tests for policies.LimitTotalConnectionsFactory"""
|
||||
def testConnectionCounting(self):
|
||||
# Make a basic factory
|
||||
factory = policies.LimitTotalConnectionsFactory()
|
||||
factory.protocol = protocol.Protocol
|
||||
|
||||
# connectionCount starts at zero
|
||||
self.assertEqual(0, factory.connectionCount)
|
||||
|
||||
# connectionCount increments as connections are made
|
||||
p1 = factory.buildProtocol(None)
|
||||
self.assertEqual(1, factory.connectionCount)
|
||||
p2 = factory.buildProtocol(None)
|
||||
self.assertEqual(2, factory.connectionCount)
|
||||
|
||||
# and decrements as they are lost
|
||||
p1.connectionLost(None)
|
||||
self.assertEqual(1, factory.connectionCount)
|
||||
p2.connectionLost(None)
|
||||
self.assertEqual(0, factory.connectionCount)
|
||||
|
||||
def testConnectionLimiting(self):
|
||||
# Make a basic factory with a connection limit of 1
|
||||
factory = policies.LimitTotalConnectionsFactory()
|
||||
factory.protocol = protocol.Protocol
|
||||
factory.connectionLimit = 1
|
||||
|
||||
# Make a connection
|
||||
p = factory.buildProtocol(None)
|
||||
self.assertNotEqual(None, p)
|
||||
self.assertEqual(1, factory.connectionCount)
|
||||
|
||||
# Try to make a second connection, which will exceed the connection
|
||||
# limit. This should return None, because overflowProtocol is None.
|
||||
self.assertEqual(None, factory.buildProtocol(None))
|
||||
self.assertEqual(1, factory.connectionCount)
|
||||
|
||||
# Define an overflow protocol
|
||||
class OverflowProtocol(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
factory.overflowed = True
|
||||
factory.overflowProtocol = OverflowProtocol
|
||||
factory.overflowed = False
|
||||
|
||||
# Try to make a second connection again, now that we have an overflow
|
||||
# protocol. Note that overflow connections count towards the connection
|
||||
# count.
|
||||
op = factory.buildProtocol(None)
|
||||
op.makeConnection(None) # to trigger connectionMade
|
||||
self.assertEqual(True, factory.overflowed)
|
||||
self.assertEqual(2, factory.connectionCount)
|
||||
|
||||
# Close the connections.
|
||||
p.connectionLost(None)
|
||||
self.assertEqual(1, factory.connectionCount)
|
||||
op.connectionLost(None)
|
||||
self.assertEqual(0, factory.connectionCount)
|
||||
|
||||
|
||||
class WriteSequenceEchoProtocol(EchoProtocol):
|
||||
def dataReceived(self, bytes):
|
||||
if bytes.find(b'vector!') != -1:
|
||||
self.transport.writeSequence([bytes])
|
||||
else:
|
||||
EchoProtocol.dataReceived(self, bytes)
|
||||
|
||||
class TestLoggingFactory(policies.TrafficLoggingFactory):
|
||||
openFile = None
|
||||
def open(self, name):
|
||||
assert self.openFile is None, "open() called too many times"
|
||||
self.openFile = NativeStringIO()
|
||||
return self.openFile
|
||||
|
||||
|
||||
|
||||
class LoggingFactoryTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{policies.TrafficLoggingFactory}.
|
||||
"""
|
||||
|
||||
def test_thingsGetLogged(self):
|
||||
"""
|
||||
Check the output produced by L{policies.TrafficLoggingFactory}.
|
||||
"""
|
||||
wrappedFactory = Server()
|
||||
wrappedFactory.protocol = WriteSequenceEchoProtocol
|
||||
t = StringTransportWithDisconnection()
|
||||
f = TestLoggingFactory(wrappedFactory, 'test')
|
||||
p = f.buildProtocol(('1.2.3.4', 5678))
|
||||
t.protocol = p
|
||||
p.makeConnection(t)
|
||||
|
||||
v = f.openFile.getvalue()
|
||||
self.assertIn('*', v)
|
||||
self.failIf(t.value())
|
||||
|
||||
p.dataReceived(b'here are some bytes')
|
||||
|
||||
v = f.openFile.getvalue()
|
||||
self.assertIn("C 1: %r" % (b'here are some bytes',), v)
|
||||
self.assertIn("S 1: %r" % (b'here are some bytes',), v)
|
||||
self.assertEqual(t.value(), b'here are some bytes')
|
||||
|
||||
t.clear()
|
||||
p.dataReceived(b'prepare for vector! to the extreme')
|
||||
v = f.openFile.getvalue()
|
||||
self.assertIn("SV 1: %r" % ([b'prepare for vector! to the extreme'],), v)
|
||||
self.assertEqual(t.value(), b'prepare for vector! to the extreme')
|
||||
|
||||
p.loseConnection()
|
||||
|
||||
v = f.openFile.getvalue()
|
||||
self.assertIn('ConnectionDone', v)
|
||||
|
||||
|
||||
def test_counter(self):
|
||||
"""
|
||||
Test counter management with the resetCounter method.
|
||||
"""
|
||||
wrappedFactory = Server()
|
||||
f = TestLoggingFactory(wrappedFactory, 'test')
|
||||
self.assertEqual(f._counter, 0)
|
||||
f.buildProtocol(('1.2.3.4', 5678))
|
||||
self.assertEqual(f._counter, 1)
|
||||
# Reset log file
|
||||
f.openFile = None
|
||||
f.buildProtocol(('1.2.3.4', 5679))
|
||||
self.assertEqual(f._counter, 2)
|
||||
|
||||
f.resetCounter()
|
||||
self.assertEqual(f._counter, 0)
|
||||
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.postfix module.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import postfix
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
class PostfixTCPMapQuoteTestCase(unittest.TestCase):
|
||||
data = [
|
||||
# (raw, quoted, [aliasQuotedForms]),
|
||||
('foo', 'foo'),
|
||||
('foo bar', 'foo%20bar'),
|
||||
('foo\tbar', 'foo%09bar'),
|
||||
('foo\nbar', 'foo%0Abar', 'foo%0abar'),
|
||||
('foo\r\nbar', 'foo%0D%0Abar', 'foo%0D%0abar', 'foo%0d%0Abar', 'foo%0d%0abar'),
|
||||
('foo ', 'foo%20'),
|
||||
(' foo', '%20foo'),
|
||||
]
|
||||
|
||||
def testData(self):
|
||||
for entry in self.data:
|
||||
raw = entry[0]
|
||||
quoted = entry[1:]
|
||||
|
||||
self.assertEqual(postfix.quote(raw), quoted[0])
|
||||
for q in quoted:
|
||||
self.assertEqual(postfix.unquote(q), raw)
|
||||
|
||||
class PostfixTCPMapServerTestCase:
|
||||
data = {
|
||||
# 'key': 'value',
|
||||
}
|
||||
|
||||
chat = [
|
||||
# (input, expected_output),
|
||||
]
|
||||
|
||||
def test_chat(self):
|
||||
"""
|
||||
Test that I{get} and I{put} commands are responded to correctly by
|
||||
L{postfix.PostfixTCPMapServer} when its factory is an instance of
|
||||
L{postifx.PostfixTCPMapDictServerFactory}.
|
||||
"""
|
||||
factory = postfix.PostfixTCPMapDictServerFactory(self.data)
|
||||
transport = StringTransport()
|
||||
|
||||
protocol = postfix.PostfixTCPMapServer()
|
||||
protocol.service = factory
|
||||
protocol.factory = factory
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
for input, expected_output in self.chat:
|
||||
protocol.lineReceived(input)
|
||||
self.assertEqual(
|
||||
transport.value(), expected_output,
|
||||
'For %r, expected %r but got %r' % (
|
||||
input, expected_output, transport.value()))
|
||||
transport.clear()
|
||||
protocol.setTimeout(None)
|
||||
|
||||
|
||||
def test_deferredChat(self):
|
||||
"""
|
||||
Test that I{get} and I{put} commands are responded to correctly by
|
||||
L{postfix.PostfixTCPMapServer} when its factory is an instance of
|
||||
L{postifx.PostfixTCPMapDeferringDictServerFactory}.
|
||||
"""
|
||||
factory = postfix.PostfixTCPMapDeferringDictServerFactory(self.data)
|
||||
transport = StringTransport()
|
||||
|
||||
protocol = postfix.PostfixTCPMapServer()
|
||||
protocol.service = factory
|
||||
protocol.factory = factory
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
for input, expected_output in self.chat:
|
||||
protocol.lineReceived(input)
|
||||
self.assertEqual(
|
||||
transport.value(), expected_output,
|
||||
'For %r, expected %r but got %r' % (
|
||||
input, expected_output, transport.value()))
|
||||
transport.clear()
|
||||
protocol.setTimeout(None)
|
||||
|
||||
|
||||
|
||||
class Valid(PostfixTCPMapServerTestCase, unittest.TestCase):
|
||||
data = {
|
||||
'foo': 'ThisIs Foo',
|
||||
'bar': ' bar really is found\r\n',
|
||||
}
|
||||
chat = [
|
||||
('get', "400 Command 'get' takes 1 parameters.\n"),
|
||||
('get foo bar', "500 \n"),
|
||||
('put', "400 Command 'put' takes 2 parameters.\n"),
|
||||
('put foo', "400 Command 'put' takes 2 parameters.\n"),
|
||||
('put foo bar baz', "500 put is not implemented yet.\n"),
|
||||
('put foo bar', '500 put is not implemented yet.\n'),
|
||||
('get foo', '200 ThisIs%20Foo\n'),
|
||||
('get bar', '200 %20bar%20really%20is%20found%0D%0A\n'),
|
||||
('get baz', '500 \n'),
|
||||
('foo', '400 unknown command\n'),
|
||||
]
|
||||
2597
Linux_i686/lib/python2.7/site-packages/twisted/test/test_process.py
Normal file
2597
Linux_i686/lib/python2.7/site-packages/twisted/test/test_process.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,236 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols package.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.protocols import wire, portforward
|
||||
from twisted.internet import reactor, defer, address, protocol
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
|
||||
class WireTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test wire protocols.
|
||||
"""
|
||||
|
||||
def test_echo(self):
|
||||
"""
|
||||
Test wire.Echo protocol: send some data and check it send it back.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Echo()
|
||||
a.makeConnection(t)
|
||||
a.dataReceived("hello")
|
||||
a.dataReceived("world")
|
||||
a.dataReceived("how")
|
||||
a.dataReceived("are")
|
||||
a.dataReceived("you")
|
||||
self.assertEqual(t.value(), "helloworldhowareyou")
|
||||
|
||||
|
||||
def test_who(self):
|
||||
"""
|
||||
Test wire.Who protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Who()
|
||||
a.makeConnection(t)
|
||||
self.assertEqual(t.value(), "root\r\n")
|
||||
|
||||
|
||||
def test_QOTD(self):
|
||||
"""
|
||||
Test wire.QOTD protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.QOTD()
|
||||
a.makeConnection(t)
|
||||
self.assertEqual(t.value(),
|
||||
"An apple a day keeps the doctor away.\r\n")
|
||||
|
||||
|
||||
def test_discard(self):
|
||||
"""
|
||||
Test wire.Discard protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Discard()
|
||||
a.makeConnection(t)
|
||||
a.dataReceived("hello")
|
||||
a.dataReceived("world")
|
||||
a.dataReceived("how")
|
||||
a.dataReceived("are")
|
||||
a.dataReceived("you")
|
||||
self.assertEqual(t.value(), "")
|
||||
|
||||
|
||||
|
||||
class TestableProxyClientFactory(portforward.ProxyClientFactory):
|
||||
"""
|
||||
Test proxy client factory that keeps the last created protocol instance.
|
||||
|
||||
@ivar protoInstance: the last instance of the protocol.
|
||||
@type protoInstance: L{portforward.ProxyClient}
|
||||
"""
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create the protocol instance and keeps track of it.
|
||||
"""
|
||||
proto = portforward.ProxyClientFactory.buildProtocol(self, addr)
|
||||
self.protoInstance = proto
|
||||
return proto
|
||||
|
||||
|
||||
|
||||
class TestableProxyFactory(portforward.ProxyFactory):
|
||||
"""
|
||||
Test proxy factory that keeps the last created protocol instance.
|
||||
|
||||
@ivar protoInstance: the last instance of the protocol.
|
||||
@type protoInstance: L{portforward.ProxyServer}
|
||||
|
||||
@ivar clientFactoryInstance: client factory used by C{protoInstance} to
|
||||
create forward connections.
|
||||
@type clientFactoryInstance: L{TestableProxyClientFactory}
|
||||
"""
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create the protocol instance, keeps track of it, and makes it use
|
||||
C{clientFactoryInstance} as client factory.
|
||||
"""
|
||||
proto = portforward.ProxyFactory.buildProtocol(self, addr)
|
||||
self.clientFactoryInstance = TestableProxyClientFactory()
|
||||
# Force the use of this specific instance
|
||||
proto.clientProtocolFactory = lambda: self.clientFactoryInstance
|
||||
self.protoInstance = proto
|
||||
return proto
|
||||
|
||||
|
||||
|
||||
class Portforwarding(unittest.TestCase):
|
||||
"""
|
||||
Test port forwarding.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.serverProtocol = wire.Echo()
|
||||
self.clientProtocol = protocol.Protocol()
|
||||
self.openPorts = []
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.proxyServerFactory.protoInstance.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
pi = self.proxyServerFactory.clientFactoryInstance.protoInstance
|
||||
pi.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.clientProtocol.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.serverProtocol.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
return defer.gatherResults(
|
||||
[defer.maybeDeferred(p.stopListening) for p in self.openPorts])
|
||||
|
||||
|
||||
def test_portforward(self):
|
||||
"""
|
||||
Test port forwarding through Echo protocol.
|
||||
"""
|
||||
realServerFactory = protocol.ServerFactory()
|
||||
realServerFactory.protocol = lambda: self.serverProtocol
|
||||
realServerPort = reactor.listenTCP(0, realServerFactory,
|
||||
interface='127.0.0.1')
|
||||
self.openPorts.append(realServerPort)
|
||||
self.proxyServerFactory = TestableProxyFactory('127.0.0.1',
|
||||
realServerPort.getHost().port)
|
||||
proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory,
|
||||
interface='127.0.0.1')
|
||||
self.openPorts.append(proxyServerPort)
|
||||
|
||||
nBytes = 1000
|
||||
received = []
|
||||
d = defer.Deferred()
|
||||
|
||||
def testDataReceived(data):
|
||||
received.extend(data)
|
||||
if len(received) >= nBytes:
|
||||
self.assertEqual(''.join(received), 'x' * nBytes)
|
||||
d.callback(None)
|
||||
|
||||
self.clientProtocol.dataReceived = testDataReceived
|
||||
|
||||
def testConnectionMade():
|
||||
self.clientProtocol.transport.write('x' * nBytes)
|
||||
|
||||
self.clientProtocol.connectionMade = testConnectionMade
|
||||
|
||||
clientFactory = protocol.ClientFactory()
|
||||
clientFactory.protocol = lambda: self.clientProtocol
|
||||
|
||||
reactor.connectTCP(
|
||||
'127.0.0.1', proxyServerPort.getHost().port, clientFactory)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def test_registerProducers(self):
|
||||
"""
|
||||
The proxy client registers itself as a producer of the proxy server and
|
||||
vice versa.
|
||||
"""
|
||||
# create a ProxyServer instance
|
||||
addr = address.IPv4Address('TCP', '127.0.0.1', 0)
|
||||
server = portforward.ProxyFactory('127.0.0.1', 0).buildProtocol(addr)
|
||||
|
||||
# set the reactor for this test
|
||||
reactor = proto_helpers.MemoryReactor()
|
||||
server.reactor = reactor
|
||||
|
||||
# make the connection
|
||||
serverTransport = proto_helpers.StringTransport()
|
||||
server.makeConnection(serverTransport)
|
||||
|
||||
# check that the ProxyClientFactory is connecting to the backend
|
||||
self.assertEqual(len(reactor.tcpClients), 1)
|
||||
# get the factory instance and check it's the one we expect
|
||||
host, port, clientFactory, timeout, _ = reactor.tcpClients[0]
|
||||
self.assertIsInstance(clientFactory, portforward.ProxyClientFactory)
|
||||
|
||||
# Connect it
|
||||
client = clientFactory.buildProtocol(addr)
|
||||
clientTransport = proto_helpers.StringTransport()
|
||||
client.makeConnection(clientTransport)
|
||||
|
||||
# check that the producers are registered
|
||||
self.assertIdentical(clientTransport.producer, serverTransport)
|
||||
self.assertIdentical(serverTransport.producer, clientTransport)
|
||||
# check the streaming attribute in both transports
|
||||
self.assertTrue(clientTransport.streaming)
|
||||
self.assertTrue(serverTransport.streaming)
|
||||
|
||||
|
||||
|
||||
class StringTransportTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test L{proto_helpers.StringTransport} helper behaviour.
|
||||
"""
|
||||
|
||||
def test_noUnicode(self):
|
||||
"""
|
||||
Test that L{proto_helpers.StringTransport} doesn't accept unicode data.
|
||||
"""
|
||||
s = proto_helpers.StringTransport()
|
||||
self.assertRaises(TypeError, s.write, u'foo')
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.python.randbytes}.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import os
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import randbytes
|
||||
|
||||
|
||||
|
||||
class SecureRandomTestCaseBase(object):
|
||||
"""
|
||||
Base class for secureRandom test cases.
|
||||
"""
|
||||
|
||||
def _check(self, source):
|
||||
"""
|
||||
The given random bytes source should return the number of bytes
|
||||
requested each time it is called and should probably not return the
|
||||
same bytes on two consecutive calls (although this is a perfectly
|
||||
legitimate occurrence and rejecting it may generate a spurious failure
|
||||
-- maybe we'll get lucky and the heat death with come first).
|
||||
"""
|
||||
for nbytes in range(17, 25):
|
||||
s = source(nbytes)
|
||||
self.assertEqual(len(s), nbytes)
|
||||
s2 = source(nbytes)
|
||||
self.assertEqual(len(s2), nbytes)
|
||||
# This is crude but hey
|
||||
self.assertNotEquals(s2, s)
|
||||
|
||||
|
||||
|
||||
class SecureRandomTestCase(SecureRandomTestCaseBase, unittest.TestCase):
|
||||
"""
|
||||
Test secureRandom under normal conditions.
|
||||
"""
|
||||
|
||||
def test_normal(self):
|
||||
"""
|
||||
L{randbytes.secureRandom} should return a string of the requested
|
||||
length and make some effort to make its result otherwise unpredictable.
|
||||
"""
|
||||
self._check(randbytes.secureRandom)
|
||||
|
||||
|
||||
|
||||
class ConditionalSecureRandomTestCase(SecureRandomTestCaseBase,
|
||||
unittest.SynchronousTestCase):
|
||||
"""
|
||||
Test random sources one by one, then remove it to.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a L{randbytes.RandomFactory} to use in the tests.
|
||||
"""
|
||||
self.factory = randbytes.RandomFactory()
|
||||
|
||||
|
||||
def errorFactory(self, nbytes):
|
||||
"""
|
||||
A factory raising an error when a source is not available.
|
||||
"""
|
||||
raise randbytes.SourceNotAvailable()
|
||||
|
||||
|
||||
def test_osUrandom(self):
|
||||
"""
|
||||
L{RandomFactory._osUrandom} should work as a random source whenever
|
||||
L{os.urandom} is available.
|
||||
"""
|
||||
self._check(self.factory._osUrandom)
|
||||
|
||||
|
||||
def test_withoutAnything(self):
|
||||
"""
|
||||
Remove all secure sources and assert it raises a failure. Then try the
|
||||
fallback parameter.
|
||||
"""
|
||||
self.factory._osUrandom = self.errorFactory
|
||||
self.assertRaises(randbytes.SecureRandomNotAvailable,
|
||||
self.factory.secureRandom, 18)
|
||||
def wrapper():
|
||||
return self.factory.secureRandom(18, fallback=True)
|
||||
s = self.assertWarns(
|
||||
RuntimeWarning,
|
||||
"urandom unavailable - "
|
||||
"proceeding with non-cryptographically secure random source",
|
||||
__file__,
|
||||
wrapper)
|
||||
self.assertEqual(len(s), 18)
|
||||
|
||||
|
||||
|
||||
class RandomTestCaseBase(SecureRandomTestCaseBase, unittest.SynchronousTestCase):
|
||||
"""
|
||||
'Normal' random test cases.
|
||||
"""
|
||||
|
||||
def test_normal(self):
|
||||
"""
|
||||
Test basic case.
|
||||
"""
|
||||
self._check(randbytes.insecureRandom)
|
||||
|
||||
|
||||
def test_withoutGetrandbits(self):
|
||||
"""
|
||||
Test C{insecureRandom} without C{random.getrandbits}.
|
||||
"""
|
||||
factory = randbytes.RandomFactory()
|
||||
factory.getrandbits = None
|
||||
self._check(factory.insecureRandom)
|
||||
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
import sys, os
|
||||
import types
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import rebuild
|
||||
|
||||
import crash_test_dummy
|
||||
f = crash_test_dummy.foo
|
||||
|
||||
class Foo: pass
|
||||
class Bar(Foo): pass
|
||||
class Baz(object): pass
|
||||
class Buz(Bar, Baz): pass
|
||||
|
||||
class HashRaisesRuntimeError:
|
||||
"""
|
||||
Things that don't hash (raise an Exception) should be ignored by the
|
||||
rebuilder.
|
||||
|
||||
@ivar hashCalled: C{bool} set to True when __hash__ is called.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.hashCalled = False
|
||||
|
||||
def __hash__(self):
|
||||
self.hashCalled = True
|
||||
raise RuntimeError('not a TypeError!')
|
||||
|
||||
|
||||
|
||||
unhashableObject = None # set in test_hashException
|
||||
|
||||
|
||||
|
||||
class RebuildTestCase(unittest.TestCase):
|
||||
"""
|
||||
Simple testcase for rebuilding, to at least exercise the code.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.libPath = self.mktemp()
|
||||
os.mkdir(self.libPath)
|
||||
self.fakelibPath = os.path.join(self.libPath, 'twisted_rebuild_fakelib')
|
||||
os.mkdir(self.fakelibPath)
|
||||
file(os.path.join(self.fakelibPath, '__init__.py'), 'w').close()
|
||||
sys.path.insert(0, self.libPath)
|
||||
|
||||
def tearDown(self):
|
||||
sys.path.remove(self.libPath)
|
||||
|
||||
def testFileRebuild(self):
|
||||
from twisted.python.util import sibpath
|
||||
import shutil, time
|
||||
shutil.copyfile(sibpath(__file__, "myrebuilder1.py"),
|
||||
os.path.join(self.fakelibPath, "myrebuilder.py"))
|
||||
from twisted_rebuild_fakelib import myrebuilder
|
||||
a = myrebuilder.A()
|
||||
try:
|
||||
object
|
||||
except NameError:
|
||||
pass
|
||||
else:
|
||||
from twisted.test import test_rebuild
|
||||
b = myrebuilder.B()
|
||||
class C(myrebuilder.B):
|
||||
pass
|
||||
test_rebuild.C = C
|
||||
c = C()
|
||||
i = myrebuilder.Inherit()
|
||||
self.assertEqual(a.a(), 'a')
|
||||
# necessary because the file has not "changed" if a second has not gone
|
||||
# by in unix. This sucks, but it's not often that you'll be doing more
|
||||
# than one reload per second.
|
||||
time.sleep(1.1)
|
||||
shutil.copyfile(sibpath(__file__, "myrebuilder2.py"),
|
||||
os.path.join(self.fakelibPath, "myrebuilder.py"))
|
||||
rebuild.rebuild(myrebuilder)
|
||||
try:
|
||||
object
|
||||
except NameError:
|
||||
pass
|
||||
else:
|
||||
b2 = myrebuilder.B()
|
||||
self.assertEqual(b2.b(), 'c')
|
||||
self.assertEqual(b.b(), 'c')
|
||||
self.assertEqual(i.a(), 'd')
|
||||
self.assertEqual(a.a(), 'b')
|
||||
# more work to be done on new-style classes
|
||||
# self.assertEqual(c.b(), 'c')
|
||||
|
||||
def testRebuild(self):
|
||||
"""
|
||||
Rebuilding an unchanged module.
|
||||
"""
|
||||
# This test would actually pass if rebuild was a no-op, but it
|
||||
# ensures rebuild doesn't break stuff while being a less
|
||||
# complex test than testFileRebuild.
|
||||
|
||||
x = crash_test_dummy.X('a')
|
||||
|
||||
rebuild.rebuild(crash_test_dummy, doLog=False)
|
||||
# Instance rebuilding is triggered by attribute access.
|
||||
x.do()
|
||||
self.failUnlessIdentical(x.__class__, crash_test_dummy.X)
|
||||
|
||||
self.failUnlessIdentical(f, crash_test_dummy.foo)
|
||||
|
||||
def testComponentInteraction(self):
|
||||
x = crash_test_dummy.XComponent()
|
||||
x.setAdapter(crash_test_dummy.IX, crash_test_dummy.XA)
|
||||
oldComponent = x.getComponent(crash_test_dummy.IX)
|
||||
rebuild.rebuild(crash_test_dummy, 0)
|
||||
newComponent = x.getComponent(crash_test_dummy.IX)
|
||||
|
||||
newComponent.method()
|
||||
|
||||
self.assertEqual(newComponent.__class__, crash_test_dummy.XA)
|
||||
|
||||
# Test that a duplicate registerAdapter is not allowed
|
||||
from twisted.python import components
|
||||
self.failUnlessRaises(ValueError, components.registerAdapter,
|
||||
crash_test_dummy.XA, crash_test_dummy.X,
|
||||
crash_test_dummy.IX)
|
||||
|
||||
def testUpdateInstance(self):
|
||||
global Foo, Buz
|
||||
|
||||
b = Buz()
|
||||
|
||||
class Foo:
|
||||
def foo(self):
|
||||
pass
|
||||
class Buz(Bar, Baz):
|
||||
x = 10
|
||||
|
||||
rebuild.updateInstance(b)
|
||||
assert hasattr(b, 'foo'), "Missing method on rebuilt instance"
|
||||
assert hasattr(b, 'x'), "Missing class attribute on rebuilt instance"
|
||||
|
||||
def testBananaInteraction(self):
|
||||
from twisted.python import rebuild
|
||||
from twisted.spread import banana
|
||||
rebuild.latestClass(banana.Banana)
|
||||
|
||||
|
||||
def test_hashException(self):
|
||||
"""
|
||||
Rebuilding something that has a __hash__ that raises a non-TypeError
|
||||
shouldn't cause rebuild to die.
|
||||
"""
|
||||
global unhashableObject
|
||||
unhashableObject = HashRaisesRuntimeError()
|
||||
def _cleanup():
|
||||
global unhashableObject
|
||||
unhashableObject = None
|
||||
self.addCleanup(_cleanup)
|
||||
rebuild.rebuild(rebuild)
|
||||
self.assertEqual(unhashableObject.hashCalled, True)
|
||||
|
||||
|
||||
|
||||
class NewStyleTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for rebuilding new-style classes of various sorts.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.m = types.ModuleType('whipping')
|
||||
sys.modules['whipping'] = self.m
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
del sys.modules['whipping']
|
||||
del self.m
|
||||
|
||||
|
||||
def test_slots(self):
|
||||
"""
|
||||
Try to rebuild a new style class with slots defined.
|
||||
"""
|
||||
classDefinition = (
|
||||
"class SlottedClass(object):\n"
|
||||
" __slots__ = ['a']\n")
|
||||
|
||||
exec classDefinition in self.m.__dict__
|
||||
inst = self.m.SlottedClass()
|
||||
inst.a = 7
|
||||
exec classDefinition in self.m.__dict__
|
||||
rebuild.updateInstance(inst)
|
||||
self.assertEqual(inst.a, 7)
|
||||
self.assertIdentical(type(inst), self.m.SlottedClass)
|
||||
|
||||
if sys.version_info < (2, 6):
|
||||
test_slots.skip = "__class__ assignment for class with slots is only available starting Python 2.6"
|
||||
|
||||
|
||||
def test_errorSlots(self):
|
||||
"""
|
||||
Try to rebuild a new style class with slots defined: this should fail.
|
||||
"""
|
||||
classDefinition = (
|
||||
"class SlottedClass(object):\n"
|
||||
" __slots__ = ['a']\n")
|
||||
|
||||
exec classDefinition in self.m.__dict__
|
||||
inst = self.m.SlottedClass()
|
||||
inst.a = 7
|
||||
exec classDefinition in self.m.__dict__
|
||||
self.assertRaises(rebuild.RebuildError, rebuild.updateInstance, inst)
|
||||
|
||||
if sys.version_info >= (2, 6):
|
||||
test_errorSlots.skip = "__class__ assignment for class with slots should work starting Python 2.6"
|
||||
|
||||
|
||||
def test_typeSubclass(self):
|
||||
"""
|
||||
Try to rebuild a base type subclass.
|
||||
"""
|
||||
classDefinition = (
|
||||
"class ListSubclass(list):\n"
|
||||
" pass\n")
|
||||
|
||||
exec classDefinition in self.m.__dict__
|
||||
inst = self.m.ListSubclass()
|
||||
inst.append(2)
|
||||
exec classDefinition in self.m.__dict__
|
||||
rebuild.updateInstance(inst)
|
||||
self.assertEqual(inst[0], 2)
|
||||
self.assertIdentical(type(inst), self.m.ListSubclass)
|
||||
|
||||
|
||||
def test_instanceSlots(self):
|
||||
"""
|
||||
Test that when rebuilding an instance with a __slots__ attribute, it
|
||||
fails accurately instead of giving a L{rebuild.RebuildError}.
|
||||
"""
|
||||
classDefinition = (
|
||||
"class NotSlottedClass(object):\n"
|
||||
" pass\n")
|
||||
|
||||
exec classDefinition in self.m.__dict__
|
||||
inst = self.m.NotSlottedClass()
|
||||
inst.__slots__ = ['a']
|
||||
classDefinition = (
|
||||
"class NotSlottedClass:\n"
|
||||
" pass\n")
|
||||
exec classDefinition in self.m.__dict__
|
||||
# Moving from new-style class to old-style should fail.
|
||||
self.assertRaises(TypeError, rebuild.updateInstance, inst)
|
||||
|
||||
|
|
@ -0,0 +1,900 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for the L{twisted.python.reflect} module.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import os
|
||||
import weakref
|
||||
from collections import deque
|
||||
|
||||
from twisted.python.compat import _PY3
|
||||
from twisted.trial import unittest
|
||||
from twisted.trial.unittest import SynchronousTestCase as TestCase
|
||||
from twisted.python import reflect
|
||||
from twisted.python.reflect import (
|
||||
accumulateMethods, prefixedMethods, prefixedMethodNames,
|
||||
addMethodNamesToDict, fullyQualifiedName)
|
||||
from twisted.python.versions import Version
|
||||
|
||||
|
||||
class Base(object):
|
||||
"""
|
||||
A no-op class which can be used to verify the behavior of
|
||||
method-discovering APIs.
|
||||
"""
|
||||
|
||||
def method(self):
|
||||
"""
|
||||
A no-op method which can be discovered.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class Sub(Base):
|
||||
"""
|
||||
A subclass of a class with a method which can be discovered.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class Separate(object):
|
||||
"""
|
||||
A no-op class with methods with differing prefixes.
|
||||
"""
|
||||
|
||||
def good_method(self):
|
||||
"""
|
||||
A no-op method which a matching prefix to be discovered.
|
||||
"""
|
||||
|
||||
|
||||
def bad_method(self):
|
||||
"""
|
||||
A no-op method with a mismatched prefix to not be discovered.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class AccumulateMethodsTests(TestCase):
|
||||
"""
|
||||
Tests for L{accumulateMethods} which finds methods on a class hierarchy and
|
||||
adds them to a dictionary.
|
||||
"""
|
||||
|
||||
def test_ownClass(self):
|
||||
"""
|
||||
If x is and instance of Base and Base defines a method named method,
|
||||
L{accumulateMethods} adds an item to the given dictionary with
|
||||
C{"method"} as the key and a bound method object for Base.method value.
|
||||
"""
|
||||
x = Base()
|
||||
output = {}
|
||||
accumulateMethods(x, output)
|
||||
self.assertEqual({"method": x.method}, output)
|
||||
|
||||
|
||||
def test_baseClass(self):
|
||||
"""
|
||||
If x is an instance of Sub and Sub is a subclass of Base and Base
|
||||
defines a method named method, L{accumulateMethods} adds an item to the
|
||||
given dictionary with C{"method"} as the key and a bound method object
|
||||
for Base.method as the value.
|
||||
"""
|
||||
x = Sub()
|
||||
output = {}
|
||||
accumulateMethods(x, output)
|
||||
self.assertEqual({"method": x.method}, output)
|
||||
|
||||
|
||||
def test_prefix(self):
|
||||
"""
|
||||
If a prefix is given, L{accumulateMethods} limits its results to
|
||||
methods beginning with that prefix. Keys in the resulting dictionary
|
||||
also have the prefix removed from them.
|
||||
"""
|
||||
x = Separate()
|
||||
output = {}
|
||||
accumulateMethods(x, output, 'good_')
|
||||
self.assertEqual({'method': x.good_method}, output)
|
||||
|
||||
|
||||
|
||||
class PrefixedMethodsTests(TestCase):
|
||||
"""
|
||||
Tests for L{prefixedMethods} which finds methods on a class hierarchy and
|
||||
adds them to a dictionary.
|
||||
"""
|
||||
|
||||
def test_onlyObject(self):
|
||||
"""
|
||||
L{prefixedMethods} returns a list of the methods discovered on an
|
||||
object.
|
||||
"""
|
||||
x = Base()
|
||||
output = prefixedMethods(x)
|
||||
self.assertEqual([x.method], output)
|
||||
|
||||
|
||||
def test_prefix(self):
|
||||
"""
|
||||
If a prefix is given, L{prefixedMethods} returns only methods named
|
||||
with that prefix.
|
||||
"""
|
||||
x = Separate()
|
||||
output = prefixedMethods(x, 'good_')
|
||||
self.assertEqual([x.good_method], output)
|
||||
|
||||
|
||||
|
||||
class PrefixedMethodNamesTests(TestCase):
|
||||
"""
|
||||
Tests for L{prefixedMethodNames}.
|
||||
"""
|
||||
def test_method(self):
|
||||
"""
|
||||
L{prefixedMethodNames} returns a list including methods with the given
|
||||
prefix defined on the class passed to it.
|
||||
"""
|
||||
self.assertEqual(["method"], prefixedMethodNames(Separate, "good_"))
|
||||
|
||||
|
||||
def test_inheritedMethod(self):
|
||||
"""
|
||||
L{prefixedMethodNames} returns a list included methods with the given
|
||||
prefix defined on base classes of the class passed to it.
|
||||
"""
|
||||
class Child(Separate):
|
||||
pass
|
||||
self.assertEqual(["method"], prefixedMethodNames(Child, "good_"))
|
||||
|
||||
|
||||
|
||||
class AddMethodNamesToDictTests(TestCase):
|
||||
"""
|
||||
Tests for L{addMethodNamesToDict}.
|
||||
"""
|
||||
def test_baseClass(self):
|
||||
"""
|
||||
If C{baseClass} is passed to L{addMethodNamesToDict}, only methods which
|
||||
are a subclass of C{baseClass} are added to the result dictionary.
|
||||
"""
|
||||
class Alternate(object):
|
||||
pass
|
||||
|
||||
class Child(Separate, Alternate):
|
||||
def good_alternate(self):
|
||||
pass
|
||||
|
||||
result = {}
|
||||
addMethodNamesToDict(Child, result, 'good_', Alternate)
|
||||
self.assertEqual({'alternate': 1}, result)
|
||||
|
||||
|
||||
|
||||
class Summer(object):
|
||||
"""
|
||||
A class we look up as part of the LookupsTestCase.
|
||||
"""
|
||||
|
||||
def reallySet(self):
|
||||
"""
|
||||
Do something.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class LookupsTestCase(TestCase):
|
||||
"""
|
||||
Tests for L{namedClass}, L{namedModule}, and L{namedAny}.
|
||||
"""
|
||||
|
||||
def test_namedClassLookup(self):
|
||||
"""
|
||||
L{namedClass} should return the class object for the name it is passed.
|
||||
"""
|
||||
self.assertIdentical(
|
||||
reflect.namedClass("twisted.test.test_reflect.Summer"),
|
||||
Summer)
|
||||
|
||||
|
||||
def test_namedModuleLookup(self):
|
||||
"""
|
||||
L{namedModule} should return the module object for the name it is
|
||||
passed.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
self.assertIdentical(
|
||||
reflect.namedModule("twisted.python.monkey"), monkey)
|
||||
|
||||
|
||||
def test_namedAnyPackageLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the package object for the name it is passed.
|
||||
"""
|
||||
import twisted.python
|
||||
self.assertIdentical(
|
||||
reflect.namedAny("twisted.python"), twisted.python)
|
||||
|
||||
|
||||
def test_namedAnyModuleLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the module object for the name it is passed.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
self.assertIdentical(
|
||||
reflect.namedAny("twisted.python.monkey"), monkey)
|
||||
|
||||
|
||||
def test_namedAnyClassLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the class object for the name it is passed.
|
||||
"""
|
||||
self.assertIdentical(
|
||||
reflect.namedAny("twisted.test.test_reflect.Summer"),
|
||||
Summer)
|
||||
|
||||
|
||||
def test_namedAnyAttributeLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the object an attribute of a non-module,
|
||||
non-package object is bound to for the name it is passed.
|
||||
"""
|
||||
# Note - not assertEqual because unbound method lookup creates a new
|
||||
# object every time. This is a foolishness of Python's object
|
||||
# implementation, not a bug in Twisted.
|
||||
self.assertEqual(
|
||||
reflect.namedAny(
|
||||
"twisted.test.test_reflect.Summer.reallySet"),
|
||||
Summer.reallySet)
|
||||
|
||||
|
||||
def test_namedAnySecondAttributeLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the object an attribute of an object which
|
||||
itself was an attribute of a non-module, non-package object is bound to
|
||||
for the name it is passed.
|
||||
"""
|
||||
self.assertIdentical(
|
||||
reflect.namedAny(
|
||||
"twisted.test.test_reflect."
|
||||
"Summer.reallySet.__doc__"),
|
||||
Summer.reallySet.__doc__)
|
||||
|
||||
|
||||
def test_importExceptions(self):
|
||||
"""
|
||||
Exceptions raised by modules which L{namedAny} causes to be imported
|
||||
should pass through L{namedAny} to the caller.
|
||||
"""
|
||||
self.assertRaises(
|
||||
ZeroDivisionError,
|
||||
reflect.namedAny, "twisted.test.reflect_helper_ZDE")
|
||||
# Make sure that there is post-failed-import cleanup
|
||||
self.assertRaises(
|
||||
ZeroDivisionError,
|
||||
reflect.namedAny, "twisted.test.reflect_helper_ZDE")
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
reflect.namedAny, "twisted.test.reflect_helper_VE")
|
||||
# Modules which themselves raise ImportError when imported should
|
||||
# result in an ImportError
|
||||
self.assertRaises(
|
||||
ImportError,
|
||||
reflect.namedAny, "twisted.test.reflect_helper_IE")
|
||||
|
||||
|
||||
def test_attributeExceptions(self):
|
||||
"""
|
||||
If segments on the end of a fully-qualified Python name represents
|
||||
attributes which aren't actually present on the object represented by
|
||||
the earlier segments, L{namedAny} should raise an L{AttributeError}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
AttributeError,
|
||||
reflect.namedAny, "twisted.nosuchmoduleintheworld")
|
||||
# ImportError behaves somewhat differently between "import
|
||||
# extant.nonextant" and "import extant.nonextant.nonextant", so test
|
||||
# the latter as well.
|
||||
self.assertRaises(
|
||||
AttributeError,
|
||||
reflect.namedAny, "twisted.nosuch.modulein.theworld")
|
||||
self.assertRaises(
|
||||
AttributeError,
|
||||
reflect.namedAny,
|
||||
"twisted.test.test_reflect.Summer.nosuchattribute")
|
||||
|
||||
|
||||
def test_invalidNames(self):
|
||||
"""
|
||||
Passing a name which isn't a fully-qualified Python name to L{namedAny}
|
||||
should result in one of the following exceptions:
|
||||
- L{InvalidName}: the name is not a dot-separated list of Python
|
||||
objects
|
||||
- L{ObjectNotFound}: the object doesn't exist
|
||||
- L{ModuleNotFound}: the object doesn't exist and there is only one
|
||||
component in the name
|
||||
"""
|
||||
err = self.assertRaises(reflect.ModuleNotFound, reflect.namedAny,
|
||||
'nosuchmoduleintheworld')
|
||||
self.assertEqual(str(err), "No module named 'nosuchmoduleintheworld'")
|
||||
|
||||
# This is a dot-separated list, but it isn't valid!
|
||||
err = self.assertRaises(reflect.ObjectNotFound, reflect.namedAny,
|
||||
"@#$@(#.!@(#!@#")
|
||||
self.assertEqual(str(err), "'@#$@(#.!@(#!@#' does not name an object")
|
||||
|
||||
err = self.assertRaises(reflect.ObjectNotFound, reflect.namedAny,
|
||||
"tcelfer.nohtyp.detsiwt")
|
||||
self.assertEqual(
|
||||
str(err),
|
||||
"'tcelfer.nohtyp.detsiwt' does not name an object")
|
||||
|
||||
err = self.assertRaises(reflect.InvalidName, reflect.namedAny, '')
|
||||
self.assertEqual(str(err), 'Empty module name')
|
||||
|
||||
for invalidName in ['.twisted', 'twisted.', 'twisted..python']:
|
||||
err = self.assertRaises(
|
||||
reflect.InvalidName, reflect.namedAny, invalidName)
|
||||
self.assertEqual(
|
||||
str(err),
|
||||
"name must be a string giving a '.'-separated list of Python "
|
||||
"identifiers, not %r" % (invalidName,))
|
||||
|
||||
|
||||
def test_requireModuleImportError(self):
|
||||
"""
|
||||
When module import fails with ImportError it returns the specified
|
||||
default value.
|
||||
"""
|
||||
for name in ['nosuchmtopodule', 'no.such.module']:
|
||||
default = object()
|
||||
|
||||
result = reflect.requireModule(name, default=default)
|
||||
|
||||
self.assertIs(result, default)
|
||||
|
||||
|
||||
def test_requireModuleDefaultNone(self):
|
||||
"""
|
||||
When module import fails it returns C{None} by default.
|
||||
"""
|
||||
result = reflect.requireModule('no.such.module')
|
||||
|
||||
self.assertIs(None, result)
|
||||
|
||||
|
||||
def test_requireModuleRequestedImport(self):
|
||||
"""
|
||||
When module import succeed it returns the module and not the default
|
||||
value.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
default = object()
|
||||
|
||||
self.assertIs(
|
||||
reflect.requireModule('twisted.python.monkey', default=default),
|
||||
monkey,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Breakable(object):
|
||||
|
||||
breakRepr = False
|
||||
breakStr = False
|
||||
|
||||
def __str__(self):
|
||||
if self.breakStr:
|
||||
raise RuntimeError("str!")
|
||||
else:
|
||||
return '<Breakable>'
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
if self.breakRepr:
|
||||
raise RuntimeError("repr!")
|
||||
else:
|
||||
return 'Breakable()'
|
||||
|
||||
|
||||
|
||||
class BrokenType(Breakable, type):
|
||||
breakName = False
|
||||
|
||||
def get___name__(self):
|
||||
if self.breakName:
|
||||
raise RuntimeError("no name")
|
||||
return 'BrokenType'
|
||||
__name__ = property(get___name__)
|
||||
|
||||
|
||||
|
||||
BTBase = BrokenType('BTBase', (Breakable,),
|
||||
{"breakRepr": True,
|
||||
"breakStr": True})
|
||||
|
||||
|
||||
|
||||
class NoClassAttr(Breakable):
|
||||
__class__ = property(lambda x: x.not_class)
|
||||
|
||||
|
||||
|
||||
class SafeRepr(TestCase):
|
||||
"""
|
||||
Tests for L{reflect.safe_repr} function.
|
||||
"""
|
||||
|
||||
def test_workingRepr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} produces the same output as C{repr} on a working
|
||||
object.
|
||||
"""
|
||||
x = [1, 2, 3]
|
||||
self.assertEqual(reflect.safe_repr(x), repr(x))
|
||||
|
||||
|
||||
def test_brokenRepr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} returns a string with class name, address, and
|
||||
traceback when the repr call failed.
|
||||
"""
|
||||
b = Breakable()
|
||||
b.breakRepr = True
|
||||
bRepr = reflect.safe_repr(b)
|
||||
self.assertIn("Breakable instance at 0x", bRepr)
|
||||
# Check that the file is in the repr, but without the extension as it
|
||||
# can be .py/.pyc
|
||||
self.assertIn(os.path.splitext(__file__)[0], bRepr)
|
||||
self.assertIn("RuntimeError: repr!", bRepr)
|
||||
|
||||
|
||||
def test_brokenStr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} isn't affected by a broken C{__str__} method.
|
||||
"""
|
||||
b = Breakable()
|
||||
b.breakStr = True
|
||||
self.assertEqual(reflect.safe_repr(b), repr(b))
|
||||
|
||||
|
||||
def test_brokenClassRepr(self):
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
reflect.safe_repr(X)
|
||||
reflect.safe_repr(X())
|
||||
|
||||
|
||||
def test_brokenReprIncludesID(self):
|
||||
"""
|
||||
C{id} is used to print the ID of the object in case of an error.
|
||||
|
||||
L{safe_repr} includes a traceback after a newline, so we only check
|
||||
against the first line of the repr.
|
||||
"""
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
|
||||
xRepr = reflect.safe_repr(X)
|
||||
xReprExpected = ('<BrokenType instance at 0x%x with repr error:'
|
||||
% (id(X),))
|
||||
self.assertEqual(xReprExpected, xRepr.split('\n')[0])
|
||||
|
||||
|
||||
def test_brokenClassStr(self):
|
||||
class X(BTBase):
|
||||
breakStr = True
|
||||
reflect.safe_repr(X)
|
||||
reflect.safe_repr(X())
|
||||
|
||||
|
||||
def test_brokenClassAttribute(self):
|
||||
"""
|
||||
If an object raises an exception when accessing its C{__class__}
|
||||
attribute, L{reflect.safe_repr} uses C{type} to retrieve the class
|
||||
object.
|
||||
"""
|
||||
b = NoClassAttr()
|
||||
b.breakRepr = True
|
||||
bRepr = reflect.safe_repr(b)
|
||||
self.assertIn("NoClassAttr instance at 0x", bRepr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], bRepr)
|
||||
self.assertIn("RuntimeError: repr!", bRepr)
|
||||
|
||||
|
||||
def test_brokenClassNameAttribute(self):
|
||||
"""
|
||||
If a class raises an exception when accessing its C{__name__} attribute
|
||||
B{and} when calling its C{__str__} implementation, L{reflect.safe_repr}
|
||||
returns 'BROKEN CLASS' instead of the class name.
|
||||
"""
|
||||
class X(BTBase):
|
||||
breakName = True
|
||||
xRepr = reflect.safe_repr(X())
|
||||
self.assertIn("<BROKEN CLASS AT 0x", xRepr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], xRepr)
|
||||
self.assertIn("RuntimeError: repr!", xRepr)
|
||||
|
||||
|
||||
|
||||
class SafeStr(TestCase):
|
||||
"""
|
||||
Tests for L{reflect.safe_str} function.
|
||||
"""
|
||||
|
||||
def test_workingStr(self):
|
||||
x = [1, 2, 3]
|
||||
self.assertEqual(reflect.safe_str(x), str(x))
|
||||
|
||||
|
||||
def test_brokenStr(self):
|
||||
b = Breakable()
|
||||
b.breakStr = True
|
||||
reflect.safe_str(b)
|
||||
|
||||
|
||||
def test_brokenRepr(self):
|
||||
b = Breakable()
|
||||
b.breakRepr = True
|
||||
reflect.safe_str(b)
|
||||
|
||||
|
||||
def test_brokenClassStr(self):
|
||||
class X(BTBase):
|
||||
breakStr = True
|
||||
reflect.safe_str(X)
|
||||
reflect.safe_str(X())
|
||||
|
||||
|
||||
def test_brokenClassRepr(self):
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
reflect.safe_str(X)
|
||||
reflect.safe_str(X())
|
||||
|
||||
|
||||
def test_brokenClassAttribute(self):
|
||||
"""
|
||||
If an object raises an exception when accessing its C{__class__}
|
||||
attribute, L{reflect.safe_str} uses C{type} to retrieve the class
|
||||
object.
|
||||
"""
|
||||
b = NoClassAttr()
|
||||
b.breakStr = True
|
||||
bStr = reflect.safe_str(b)
|
||||
self.assertIn("NoClassAttr instance at 0x", bStr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], bStr)
|
||||
self.assertIn("RuntimeError: str!", bStr)
|
||||
|
||||
|
||||
def test_brokenClassNameAttribute(self):
|
||||
"""
|
||||
If a class raises an exception when accessing its C{__name__} attribute
|
||||
B{and} when calling its C{__str__} implementation, L{reflect.safe_str}
|
||||
returns 'BROKEN CLASS' instead of the class name.
|
||||
"""
|
||||
class X(BTBase):
|
||||
breakName = True
|
||||
xStr = reflect.safe_str(X())
|
||||
self.assertIn("<BROKEN CLASS AT 0x", xStr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], xStr)
|
||||
self.assertIn("RuntimeError: str!", xStr)
|
||||
|
||||
|
||||
|
||||
class FilenameToModule(TestCase):
|
||||
"""
|
||||
Test L{filenameToModuleName} detection.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.path = os.path.join(self.mktemp(), "fakepackage", "test")
|
||||
os.makedirs(self.path)
|
||||
with open(os.path.join(self.path, "__init__.py"), "w") as f:
|
||||
f.write("")
|
||||
with open(os.path.join(os.path.dirname(self.path), "__init__.py"),
|
||||
"w") as f:
|
||||
f.write("")
|
||||
|
||||
|
||||
def test_directory(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module (a package) given a
|
||||
directory.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(self.path)
|
||||
self.assertEqual(module, 'fakepackage.test')
|
||||
module = reflect.filenameToModuleName(self.path + os.path.sep)
|
||||
self.assertEqual(module, 'fakepackage.test')
|
||||
|
||||
|
||||
def test_file(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module given the path to
|
||||
its file.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(
|
||||
os.path.join(self.path, 'test_reflect.py'))
|
||||
self.assertEqual(module, 'fakepackage.test.test_reflect')
|
||||
|
||||
|
||||
def test_bytes(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module given a C{bytes}
|
||||
path to its file.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(
|
||||
os.path.join(self.path.encode("utf-8"), b'test_reflect.py'))
|
||||
# Module names are always native string:
|
||||
self.assertEqual(module, 'fakepackage.test.test_reflect')
|
||||
|
||||
|
||||
|
||||
class FullyQualifiedNameTests(TestCase):
|
||||
"""
|
||||
Test for L{fullyQualifiedName}.
|
||||
"""
|
||||
|
||||
def _checkFullyQualifiedName(self, obj, expected):
|
||||
"""
|
||||
Helper to check that fully qualified name of C{obj} results to
|
||||
C{expected}.
|
||||
"""
|
||||
self.assertEqual(fullyQualifiedName(obj), expected)
|
||||
|
||||
|
||||
def test_package(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the full name of a package and a
|
||||
subpackage.
|
||||
"""
|
||||
import twisted
|
||||
self._checkFullyQualifiedName(twisted, 'twisted')
|
||||
import twisted.python
|
||||
self._checkFullyQualifiedName(twisted.python, 'twisted.python')
|
||||
|
||||
|
||||
def test_module(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a module inside a a package.
|
||||
"""
|
||||
import twisted.python.compat
|
||||
self._checkFullyQualifiedName(
|
||||
twisted.python.compat, 'twisted.python.compat')
|
||||
|
||||
|
||||
def test_class(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
FullyQualifiedNameTests,
|
||||
'%s.FullyQualifiedNameTests' % (__name__,))
|
||||
|
||||
|
||||
def test_function(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a function inside its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
fullyQualifiedName, "twisted.python.reflect.fullyQualifiedName")
|
||||
|
||||
|
||||
def test_boundMethod(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a bound method inside its
|
||||
class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
self.test_boundMethod,
|
||||
"%s.%s.test_boundMethod" % (__name__, self.__class__.__name__))
|
||||
|
||||
|
||||
def test_unboundMethod(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of an unbound method inside its
|
||||
class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
self.__class__.test_unboundMethod,
|
||||
"%s.%s.test_unboundMethod" % (__name__, self.__class__.__name__))
|
||||
|
||||
|
||||
class ObjectGrep(unittest.TestCase):
|
||||
if _PY3:
|
||||
# This is to be removed when fixing #6986
|
||||
skip = "twisted.python.reflect.objgrep hasn't been ported to Python 3"
|
||||
|
||||
|
||||
def test_dictionary(self):
|
||||
"""
|
||||
Test references search through a dictionnary, as a key or as a value.
|
||||
"""
|
||||
o = object()
|
||||
d1 = {None: o}
|
||||
d2 = {o: None}
|
||||
|
||||
self.assertIn("[None]", reflect.objgrep(d1, o, reflect.isSame))
|
||||
self.assertIn("{None}", reflect.objgrep(d2, o, reflect.isSame))
|
||||
|
||||
def test_list(self):
|
||||
"""
|
||||
Test references search through a list.
|
||||
"""
|
||||
o = object()
|
||||
L = [None, o]
|
||||
|
||||
self.assertIn("[1]", reflect.objgrep(L, o, reflect.isSame))
|
||||
|
||||
def test_tuple(self):
|
||||
"""
|
||||
Test references search through a tuple.
|
||||
"""
|
||||
o = object()
|
||||
T = (o, None)
|
||||
|
||||
self.assertIn("[0]", reflect.objgrep(T, o, reflect.isSame))
|
||||
|
||||
def test_instance(self):
|
||||
"""
|
||||
Test references search through an object attribute.
|
||||
"""
|
||||
class Dummy:
|
||||
pass
|
||||
o = object()
|
||||
d = Dummy()
|
||||
d.o = o
|
||||
|
||||
self.assertIn(".o", reflect.objgrep(d, o, reflect.isSame))
|
||||
|
||||
def test_weakref(self):
|
||||
"""
|
||||
Test references search through a weakref object.
|
||||
"""
|
||||
class Dummy:
|
||||
pass
|
||||
o = Dummy()
|
||||
w1 = weakref.ref(o)
|
||||
|
||||
self.assertIn("()", reflect.objgrep(w1, o, reflect.isSame))
|
||||
|
||||
def test_boundMethod(self):
|
||||
"""
|
||||
Test references search through method special attributes.
|
||||
"""
|
||||
class Dummy:
|
||||
def dummy(self):
|
||||
pass
|
||||
o = Dummy()
|
||||
m = o.dummy
|
||||
|
||||
self.assertIn(".__self__",
|
||||
reflect.objgrep(m, m.__self__, reflect.isSame))
|
||||
self.assertIn(".__self__.__class__",
|
||||
reflect.objgrep(m, m.__self__.__class__, reflect.isSame))
|
||||
self.assertIn(".__func__",
|
||||
reflect.objgrep(m, m.__func__, reflect.isSame))
|
||||
|
||||
def test_everything(self):
|
||||
"""
|
||||
Test references search using complex set of objects.
|
||||
"""
|
||||
class Dummy:
|
||||
def method(self):
|
||||
pass
|
||||
|
||||
o = Dummy()
|
||||
D1 = {(): "baz", None: "Quux", o: "Foosh"}
|
||||
L = [None, (), D1, 3]
|
||||
T = (L, {}, Dummy())
|
||||
D2 = {0: "foo", 1: "bar", 2: T}
|
||||
i = Dummy()
|
||||
i.attr = D2
|
||||
m = i.method
|
||||
w = weakref.ref(m)
|
||||
|
||||
self.assertIn("().__self__.attr[2][0][2]{'Foosh'}",
|
||||
reflect.objgrep(w, o, reflect.isSame))
|
||||
|
||||
def test_depthLimit(self):
|
||||
"""
|
||||
Test the depth of references search.
|
||||
"""
|
||||
a = []
|
||||
b = [a]
|
||||
c = [a, b]
|
||||
d = [a, c]
|
||||
|
||||
self.assertEqual(['[0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=1))
|
||||
self.assertEqual(['[0]', '[1][0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=2))
|
||||
self.assertEqual(['[0]', '[1][0]', '[1][1][0]'], reflect.objgrep(d, a, reflect.isSame, maxDepth=3))
|
||||
|
||||
|
||||
def test_deque(self):
|
||||
"""
|
||||
Test references search through a deque object.
|
||||
"""
|
||||
o = object()
|
||||
D = deque()
|
||||
D.append(None)
|
||||
D.append(o)
|
||||
|
||||
self.assertIn("[1]", reflect.objgrep(D, o, reflect.isSame))
|
||||
|
||||
|
||||
class GetClass(unittest.TestCase):
|
||||
if _PY3:
|
||||
oldClassNames = ['type']
|
||||
else:
|
||||
oldClassNames = ['class', 'classobj']
|
||||
|
||||
def testOld(self):
|
||||
class OldClass:
|
||||
pass
|
||||
old = OldClass()
|
||||
self.assertIn(reflect.getClass(OldClass).__name__, self.oldClassNames)
|
||||
self.assertEqual(reflect.getClass(old).__name__, 'OldClass')
|
||||
|
||||
def testNew(self):
|
||||
class NewClass(object):
|
||||
pass
|
||||
new = NewClass()
|
||||
self.assertEqual(reflect.getClass(NewClass).__name__, 'type')
|
||||
self.assertEqual(reflect.getClass(new).__name__, 'NewClass')
|
||||
|
||||
|
||||
if not _PY3:
|
||||
# The functions tested below are deprecated but still used by external
|
||||
# projects like Nevow 0.10. They are not going to be ported to Python 3
|
||||
# (hence the condition above) and will be removed as soon as no project used
|
||||
# by Twisted will depend on these functions. Also, have a look at the
|
||||
# comments related to those functions in twisted.python.reflect.
|
||||
class DeprecationTestCase(unittest.TestCase):
|
||||
"""
|
||||
Test deprecations in twisted.python.reflect
|
||||
"""
|
||||
|
||||
def test_allYourBase(self):
|
||||
"""
|
||||
Test deprecation of L{reflect.allYourBase}. See #5481 for removal.
|
||||
"""
|
||||
self.callDeprecated(
|
||||
(Version("Twisted", 11, 0, 0), "inspect.getmro"),
|
||||
reflect.allYourBase, DeprecationTestCase)
|
||||
|
||||
|
||||
def test_accumulateBases(self):
|
||||
"""
|
||||
Test deprecation of L{reflect.accumulateBases}. See #5481 for removal.
|
||||
"""
|
||||
l = []
|
||||
self.callDeprecated(
|
||||
(Version("Twisted", 11, 0, 0), "inspect.getmro"),
|
||||
reflect.accumulateBases, DeprecationTestCase, l, None)
|
||||
|
||||
|
||||
def test_getcurrent(self):
|
||||
"""
|
||||
Test deprecation of L{reflect.getcurrent}.
|
||||
"""
|
||||
|
||||
class C:
|
||||
pass
|
||||
|
||||
self.callDeprecated(
|
||||
Version("Twisted", 14, 0, 0),
|
||||
reflect.getcurrent, C)
|
||||
|
||||
|
||||
def test_isinst(self):
|
||||
"""
|
||||
Test deprecation of L{reflect.isinst}.
|
||||
"""
|
||||
|
||||
self.callDeprecated(
|
||||
(Version("Twisted", 14, 0, 0), "isinstance"),
|
||||
reflect.isinst, object(), object)
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import roots
|
||||
import types
|
||||
|
||||
class RootsTest(unittest.TestCase):
|
||||
|
||||
def testExceptions(self):
|
||||
request = roots.Request()
|
||||
try:
|
||||
request.write("blah")
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
try:
|
||||
request.finish()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
|
||||
def testCollection(self):
|
||||
collection = roots.Collection()
|
||||
collection.putEntity("x", 'test')
|
||||
self.assertEqual(collection.getStaticEntity("x"),
|
||||
'test')
|
||||
collection.delEntity("x")
|
||||
self.assertEqual(collection.getStaticEntity('x'),
|
||||
None)
|
||||
try:
|
||||
collection.storeEntity("x", None)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
try:
|
||||
collection.removeEntity("x", None)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
|
||||
def testConstrained(self):
|
||||
class const(roots.Constrained):
|
||||
def nameConstraint(self, name):
|
||||
return (name == 'x')
|
||||
c = const()
|
||||
self.assertEqual(c.putEntity('x', 'test'), None)
|
||||
self.failUnlessRaises(roots.ConstraintViolation,
|
||||
c.putEntity, 'y', 'test')
|
||||
|
||||
|
||||
def testHomogenous(self):
|
||||
h = roots.Homogenous()
|
||||
h.entityType = types.IntType
|
||||
h.putEntity('a', 1)
|
||||
self.assertEqual(h.getStaticEntity('a'),1 )
|
||||
self.failUnlessRaises(roots.ConstraintViolation,
|
||||
h.putEntity, 'x', 'y')
|
||||
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for C{setup.py}, Twisted's distutils integration file.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import os, sys
|
||||
|
||||
import twisted
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.dist import getExtensions
|
||||
|
||||
# Get rid of the UTF-8 encoding and bytes topfiles segment when FilePath
|
||||
# supports unicode. #2366, #4736, #5203. Also #4743, which requires checking
|
||||
# setup.py, not just the topfiles directory.
|
||||
if not FilePath(twisted.__file__.encode('utf-8')).sibling(b'topfiles').child(b'setup.py').exists():
|
||||
sourceSkip = "Only applies to source checkout of Twisted"
|
||||
else:
|
||||
sourceSkip = None
|
||||
|
||||
|
||||
class TwistedExtensionsTests(SynchronousTestCase):
|
||||
if sourceSkip is not None:
|
||||
skip = sourceSkip
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Change the working directory to the parent of the C{twisted} package so
|
||||
that L{twisted.python.dist.getExtensions} finds Twisted's own extension
|
||||
definitions.
|
||||
"""
|
||||
self.addCleanup(os.chdir, os.getcwd())
|
||||
os.chdir(FilePath(twisted.__file__).parent().parent().path)
|
||||
|
||||
|
||||
def test_initgroups(self):
|
||||
"""
|
||||
If C{os.initgroups} is present (Python 2.7 and Python 3.3 and newer),
|
||||
L{twisted.python._initgroups} is not returned as an extension to build
|
||||
from L{getExtensions}.
|
||||
"""
|
||||
extensions = getExtensions()
|
||||
found = None
|
||||
for extension in extensions:
|
||||
if extension.name == "twisted.python._initgroups":
|
||||
found = extension
|
||||
|
||||
if sys.version_info[:2] >= (2, 7):
|
||||
self.assertIdentical(
|
||||
None, found,
|
||||
"Should not have found twisted.python._initgroups extension "
|
||||
"definition.")
|
||||
else:
|
||||
self.assertNotIdentical(
|
||||
None, found,
|
||||
"Should have found twisted.python._initgroups extension "
|
||||
"definition.")
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""Test win32 shortcut script
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
import os
|
||||
if os.name == 'nt':
|
||||
|
||||
skipWindowsNopywin32 = None
|
||||
try:
|
||||
from twisted.python import shortcut
|
||||
except ImportError:
|
||||
skipWindowsNopywin32 = ("On windows, twisted.python.shortcut is not "
|
||||
"available in the absence of win32com.")
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
class ShortcutTest(unittest.TestCase):
|
||||
def testCreate(self):
|
||||
s1=shortcut.Shortcut("test_shortcut.py")
|
||||
tempname=self.mktemp() + '.lnk'
|
||||
s1.save(tempname)
|
||||
self.assert_(os.path.exists(tempname))
|
||||
sc=shortcut.open(tempname)
|
||||
self.assert_(sc.GetPath(0)[0].endswith('test_shortcut.py'))
|
||||
ShortcutTest.skip = skipWindowsNopywin32
|
||||
984
Linux_i686/lib/python2.7/site-packages/twisted/test/test_sip.py
Normal file
984
Linux_i686/lib/python2.7/site-packages/twisted/test/test_sip.py
Normal file
|
|
@ -0,0 +1,984 @@
|
|||
# -*- test-case-name: twisted.test.test_sip -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""Session Initialization Protocol tests."""
|
||||
|
||||
from twisted.trial import unittest, util
|
||||
from twisted.protocols import sip
|
||||
from twisted.internet import defer, reactor, utils
|
||||
from twisted.python.versions import Version
|
||||
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
from twisted import cred
|
||||
import twisted.cred.portal
|
||||
import twisted.cred.checkers
|
||||
|
||||
from zope.interface import implements
|
||||
|
||||
|
||||
# request, prefixed by random CRLFs
|
||||
request1 = "\n\r\n\n\r" + """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 4
|
||||
|
||||
abcd""".replace("\n", "\r\n")
|
||||
|
||||
# request, no content-length
|
||||
request2 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
|
||||
1234""".replace("\n", "\r\n")
|
||||
|
||||
# request, with garbage after
|
||||
request3 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 4
|
||||
|
||||
1234
|
||||
|
||||
lalalal""".replace("\n", "\r\n")
|
||||
|
||||
# three requests
|
||||
request4 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 0
|
||||
|
||||
INVITE sip:loop SIP/2.0
|
||||
From: foo
|
||||
To: bar
|
||||
Content-Length: 4
|
||||
|
||||
abcdINVITE sip:loop SIP/2.0
|
||||
From: foo
|
||||
To: bar
|
||||
Content-Length: 4
|
||||
|
||||
1234""".replace("\n", "\r\n")
|
||||
|
||||
# response, no content
|
||||
response1 = """SIP/2.0 200 OK
|
||||
From: foo
|
||||
To:bar
|
||||
Content-Length: 0
|
||||
|
||||
""".replace("\n", "\r\n")
|
||||
|
||||
# short header version
|
||||
request_short = """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
f: mo
|
||||
t: joe
|
||||
l: 4
|
||||
|
||||
abcd""".replace("\n", "\r\n")
|
||||
|
||||
request_natted = """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
Via: SIP/2.0/UDP 10.0.0.1:5060;rport
|
||||
|
||||
""".replace("\n", "\r\n")
|
||||
|
||||
# multiline headers (example from RFC 3621).
|
||||
response_multiline = """\
|
||||
SIP/2.0 200 OK
|
||||
Via: SIP/2.0/UDP server10.biloxi.com
|
||||
;branch=z9hG4bKnashds8;received=192.0.2.3
|
||||
Via: SIP/2.0/UDP bigbox3.site3.atlanta.com
|
||||
;branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2
|
||||
Via: SIP/2.0/UDP pc33.atlanta.com
|
||||
;branch=z9hG4bK776asdhds ;received=192.0.2.1
|
||||
To: Bob <sip:bob@biloxi.com>;tag=a6c85cf
|
||||
From: Alice <sip:alice@atlanta.com>;tag=1928301774
|
||||
Call-ID: a84b4c76e66710@pc33.atlanta.com
|
||||
CSeq: 314159 INVITE
|
||||
Contact: <sip:bob@192.0.2.4>
|
||||
Content-Type: application/sdp
|
||||
Content-Length: 0
|
||||
\n""".replace("\n", "\r\n")
|
||||
|
||||
class TestRealm:
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
return sip.IContact, None, lambda: None
|
||||
|
||||
class MessageParsingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
self.parser = sip.MessagesParser(self.l.append)
|
||||
|
||||
def feedMessage(self, message):
|
||||
self.parser.dataReceived(message)
|
||||
self.parser.dataDone()
|
||||
|
||||
def validateMessage(self, m, method, uri, headers, body):
|
||||
"""Validate Requests."""
|
||||
self.assertEqual(m.method, method)
|
||||
self.assertEqual(m.uri.toString(), uri)
|
||||
self.assertEqual(m.headers, headers)
|
||||
self.assertEqual(m.body, body)
|
||||
self.assertEqual(m.finished, 1)
|
||||
|
||||
def testSimple(self):
|
||||
l = self.l
|
||||
self.feedMessage(request1)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd")
|
||||
|
||||
def testTwoMessages(self):
|
||||
l = self.l
|
||||
self.feedMessage(request1)
|
||||
self.feedMessage(request2)
|
||||
self.assertEqual(len(l), 2)
|
||||
self.validateMessage(
|
||||
l[0], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd")
|
||||
self.validateMessage(l[1], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"]},
|
||||
"1234")
|
||||
|
||||
def testGarbage(self):
|
||||
l = self.l
|
||||
self.feedMessage(request3)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"1234")
|
||||
|
||||
def testThreeInOne(self):
|
||||
l = self.l
|
||||
self.feedMessage(request4)
|
||||
self.assertEqual(len(l), 3)
|
||||
self.validateMessage(
|
||||
l[0], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["0"]},
|
||||
"")
|
||||
self.validateMessage(
|
||||
l[1], "INVITE", "sip:loop",
|
||||
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
|
||||
"abcd")
|
||||
self.validateMessage(
|
||||
l[2], "INVITE", "sip:loop",
|
||||
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
|
||||
"1234")
|
||||
|
||||
def testShort(self):
|
||||
l = self.l
|
||||
self.feedMessage(request_short)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0], "INVITE", "sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd")
|
||||
|
||||
def testSimpleResponse(self):
|
||||
l = self.l
|
||||
self.feedMessage(response1)
|
||||
self.assertEqual(len(l), 1)
|
||||
m = l[0]
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.phrase, "OK")
|
||||
self.assertEqual(
|
||||
m.headers,
|
||||
{"from": ["foo"], "to": ["bar"], "content-length": ["0"]})
|
||||
self.assertEqual(m.body, "")
|
||||
self.assertEqual(m.finished, 1)
|
||||
|
||||
|
||||
def test_multiLine(self):
|
||||
"""
|
||||
A header may be split across multiple lines. Subsequent lines begin
|
||||
with C{" "} or C{"\\t"}.
|
||||
"""
|
||||
l = self.l
|
||||
self.feedMessage(response_multiline)
|
||||
self.assertEquals(len(l), 1)
|
||||
m = l[0]
|
||||
self.assertEquals(
|
||||
m.headers['via'][0],
|
||||
"SIP/2.0/UDP server10.biloxi.com;"
|
||||
"branch=z9hG4bKnashds8;received=192.0.2.3")
|
||||
self.assertEquals(
|
||||
m.headers['via'][1],
|
||||
"SIP/2.0/UDP bigbox3.site3.atlanta.com;"
|
||||
"branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2")
|
||||
self.assertEquals(
|
||||
m.headers['via'][2],
|
||||
"SIP/2.0/UDP pc33.atlanta.com;"
|
||||
"branch=z9hG4bK776asdhds ;received=192.0.2.1")
|
||||
|
||||
|
||||
|
||||
class MessageParsingTestCase2(MessageParsingTestCase):
|
||||
"""Same as base class, but feed data char by char."""
|
||||
|
||||
def feedMessage(self, message):
|
||||
for c in message:
|
||||
self.parser.dataReceived(c)
|
||||
self.parser.dataDone()
|
||||
|
||||
|
||||
class MakeMessageTestCase(unittest.TestCase):
|
||||
|
||||
def testRequest(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("foo", "bar")
|
||||
self.assertEqual(
|
||||
r.toString(),
|
||||
"INVITE sip:foo SIP/2.0\r\nFoo: bar\r\n\r\n")
|
||||
|
||||
def testResponse(self):
|
||||
r = sip.Response(200, "OK")
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("Content-Length", "4")
|
||||
r.bodyDataReceived("1234")
|
||||
self.assertEqual(
|
||||
r.toString(),
|
||||
"SIP/2.0 200 OK\r\nFoo: bar\r\nContent-Length: 4\r\n\r\n1234")
|
||||
|
||||
def testStatusCode(self):
|
||||
r = sip.Response(200)
|
||||
self.assertEqual(r.toString(), "SIP/2.0 200 OK\r\n\r\n")
|
||||
|
||||
|
||||
class ViaTestCase(unittest.TestCase):
|
||||
|
||||
def checkRoundtrip(self, v):
|
||||
s = v.toString()
|
||||
self.assertEqual(s, sip.parseViaHeader(s).toString())
|
||||
|
||||
def testExtraWhitespace(self):
|
||||
v1 = sip.parseViaHeader('SIP/2.0/UDP 192.168.1.1:5060')
|
||||
v2 = sip.parseViaHeader('SIP/2.0/UDP 192.168.1.1:5060')
|
||||
self.assertEqual(v1.transport, v2.transport)
|
||||
self.assertEqual(v1.host, v2.host)
|
||||
self.assertEqual(v1.port, v2.port)
|
||||
|
||||
def test_complex(self):
|
||||
"""
|
||||
Test parsing a Via header with one of everything.
|
||||
"""
|
||||
s = ("SIP/2.0/UDP first.example.com:4000;ttl=16;maddr=224.2.0.1"
|
||||
" ;branch=a7c6a8dlze (Example)")
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "first.example.com")
|
||||
self.assertEqual(v.port, 4000)
|
||||
self.assertEqual(v.rport, None)
|
||||
self.assertEqual(v.rportValue, None)
|
||||
self.assertEqual(v.rportRequested, False)
|
||||
self.assertEqual(v.ttl, 16)
|
||||
self.assertEqual(v.maddr, "224.2.0.1")
|
||||
self.assertEqual(v.branch, "a7c6a8dlze")
|
||||
self.assertEqual(v.hidden, 0)
|
||||
self.assertEqual(v.toString(),
|
||||
"SIP/2.0/UDP first.example.com:4000"
|
||||
";ttl=16;branch=a7c6a8dlze;maddr=224.2.0.1")
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
def test_simple(self):
|
||||
"""
|
||||
Test parsing a simple Via header.
|
||||
"""
|
||||
s = "SIP/2.0/UDP example.com;hidden"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "example.com")
|
||||
self.assertEqual(v.port, 5060)
|
||||
self.assertEqual(v.rport, None)
|
||||
self.assertEqual(v.rportValue, None)
|
||||
self.assertEqual(v.rportRequested, False)
|
||||
self.assertEqual(v.ttl, None)
|
||||
self.assertEqual(v.maddr, None)
|
||||
self.assertEqual(v.branch, None)
|
||||
self.assertEqual(v.hidden, True)
|
||||
self.assertEqual(v.toString(),
|
||||
"SIP/2.0/UDP example.com:5060;hidden")
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
def testSimpler(self):
|
||||
v = sip.Via("example.com")
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
|
||||
def test_deprecatedRPort(self):
|
||||
"""
|
||||
Setting rport to True is deprecated, but still produces a Via header
|
||||
with the expected properties.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=True)
|
||||
|
||||
warnings = self.flushWarnings(
|
||||
offendingFunctions=[self.test_deprecatedRPort])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]['message'],
|
||||
'rport=True is deprecated since Twisted 9.0.')
|
||||
self.assertEqual(
|
||||
warnings[0]['category'],
|
||||
DeprecationWarning)
|
||||
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
|
||||
self.assertEqual(v.rport, True)
|
||||
self.assertEqual(v.rportRequested, True)
|
||||
self.assertEqual(v.rportValue, None)
|
||||
|
||||
|
||||
def test_rport(self):
|
||||
"""
|
||||
An rport setting of None should insert the parameter with no value.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=None)
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
|
||||
self.assertEqual(v.rportRequested, True)
|
||||
self.assertEqual(v.rportValue, None)
|
||||
|
||||
|
||||
def test_rportValue(self):
|
||||
"""
|
||||
An rport numeric setting should insert the parameter with the number
|
||||
value given.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=1)
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport=1")
|
||||
self.assertEqual(v.rportRequested, False)
|
||||
self.assertEqual(v.rportValue, 1)
|
||||
self.assertEqual(v.rport, 1)
|
||||
|
||||
|
||||
def testNAT(self):
|
||||
s = "SIP/2.0/UDP 10.0.0.1:5060;received=22.13.1.5;rport=12345"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "10.0.0.1")
|
||||
self.assertEqual(v.port, 5060)
|
||||
self.assertEqual(v.received, "22.13.1.5")
|
||||
self.assertEqual(v.rport, 12345)
|
||||
|
||||
self.assertNotEquals(v.toString().find("rport=12345"), -1)
|
||||
|
||||
|
||||
def test_unknownParams(self):
|
||||
"""
|
||||
Parsing and serializing Via headers with unknown parameters should work.
|
||||
"""
|
||||
s = "SIP/2.0/UDP example.com:5060;branch=a12345b;bogus;pie=delicious"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.toString(), s)
|
||||
|
||||
|
||||
|
||||
class URLTestCase(unittest.TestCase):
|
||||
|
||||
def testRoundtrip(self):
|
||||
for url in [
|
||||
"sip:j.doe@big.com",
|
||||
"sip:j.doe:secret@big.com;transport=tcp",
|
||||
"sip:j.doe@big.com?subject=project",
|
||||
"sip:example.com",
|
||||
]:
|
||||
self.assertEqual(sip.parseURL(url).toString(), url)
|
||||
|
||||
def testComplex(self):
|
||||
s = ("sip:user:pass@hosta:123;transport=udp;user=phone;method=foo;"
|
||||
"ttl=12;maddr=1.2.3.4;blah;goo=bar?a=b&c=d")
|
||||
url = sip.parseURL(s)
|
||||
for k, v in [("username", "user"), ("password", "pass"),
|
||||
("host", "hosta"), ("port", 123),
|
||||
("transport", "udp"), ("usertype", "phone"),
|
||||
("method", "foo"), ("ttl", 12),
|
||||
("maddr", "1.2.3.4"), ("other", ["blah", "goo=bar"]),
|
||||
("headers", {"a": "b", "c": "d"})]:
|
||||
self.assertEqual(getattr(url, k), v)
|
||||
|
||||
|
||||
class ParseTestCase(unittest.TestCase):
|
||||
|
||||
def testParseAddress(self):
|
||||
for address, name, urls, params in [
|
||||
('"A. G. Bell" <sip:foo@example.com>',
|
||||
"A. G. Bell", "sip:foo@example.com", {}),
|
||||
("Anon <sip:foo@example.com>", "Anon", "sip:foo@example.com", {}),
|
||||
("sip:foo@example.com", "", "sip:foo@example.com", {}),
|
||||
("<sip:foo@example.com>", "", "sip:foo@example.com", {}),
|
||||
("foo <sip:foo@example.com>;tag=bar;foo=baz", "foo",
|
||||
"sip:foo@example.com", {"tag": "bar", "foo": "baz"}),
|
||||
]:
|
||||
gname, gurl, gparams = sip.parseAddress(address)
|
||||
self.assertEqual(name, gname)
|
||||
self.assertEqual(gurl.toString(), urls)
|
||||
self.assertEqual(gparams, params)
|
||||
|
||||
|
||||
class DummyLocator:
|
||||
implements(sip.ILocator)
|
||||
def getAddress(self, logicalURL):
|
||||
return defer.succeed(sip.URL("server.com", port=5060))
|
||||
|
||||
class FailingLocator:
|
||||
implements(sip.ILocator)
|
||||
def getAddress(self, logicalURL):
|
||||
return defer.fail(LookupError())
|
||||
|
||||
|
||||
class ProxyTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.proxy = sip.Proxy("127.0.0.1")
|
||||
self.proxy.locator = DummyLocator()
|
||||
self.sent = []
|
||||
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
|
||||
|
||||
def testRequestForward(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("via", sip.Via("1.2.3.5").toString())
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
r.addHeader("contact", "<sip:joe@1.2.3.5>")
|
||||
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(dest.port, 5060)
|
||||
self.assertEqual(dest.host, "server.com")
|
||||
self.assertEqual(m.uri.toString(), "sip:foo")
|
||||
self.assertEqual(m.method, "INVITE")
|
||||
self.assertEqual(m.headers["via"],
|
||||
["SIP/2.0/UDP 127.0.0.1:5060",
|
||||
"SIP/2.0/UDP 1.2.3.4:5060",
|
||||
"SIP/2.0/UDP 1.2.3.5:5060"])
|
||||
|
||||
|
||||
def testReceivedRequestForward(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
r.addHeader("contact", "<sip:joe@1.2.3.4>")
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(m.headers["via"],
|
||||
["SIP/2.0/UDP 127.0.0.1:5060",
|
||||
"SIP/2.0/UDP 1.2.3.4:5060;received=1.1.1.1"])
|
||||
|
||||
|
||||
def testResponseWrongVia(self):
|
||||
# first via must match proxy's address
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("foo.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testResponseForward(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
r.addHeader("via", sip.Via("client.com", port=1234).toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 1234))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:1234"])
|
||||
|
||||
def testReceivedResponseForward(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
r.addHeader(
|
||||
"via",
|
||||
sip.Via("10.0.0.1", received="client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
|
||||
def testResponseToUs(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
l = []
|
||||
self.proxy.gotResponse = lambda *a: l.append(a)
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(l), 1)
|
||||
m, addr = l[0]
|
||||
self.assertEqual(len(m.headers.get("via", [])), 0)
|
||||
self.assertEqual(m.code, 200)
|
||||
|
||||
def testLoop(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(self.sent, [])
|
||||
|
||||
def testCantForwardRequest(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
self.proxy.locator = FailingLocator()
|
||||
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("1.2.3.4", 5060))
|
||||
self.assertEqual(m.code, 404)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP 1.2.3.4:5060"])
|
||||
|
||||
def testCantForwardResponse(self):
|
||||
pass
|
||||
|
||||
#testCantForwardResponse.skip = "not implemented yet"
|
||||
|
||||
|
||||
class RegistrationTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.proxy = sip.RegisterProxy(host="127.0.0.1")
|
||||
self.registry = sip.InMemoryRegistry("bell.example.com")
|
||||
self.proxy.registry = self.proxy.locator = self.registry
|
||||
self.sent = []
|
||||
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
|
||||
setUp = utils.suppressWarnings(setUp,
|
||||
util.suppress(category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
|
||||
|
||||
def tearDown(self):
|
||||
for d, uri in self.registry.users.values():
|
||||
d.cancel()
|
||||
del self.proxy
|
||||
|
||||
def register(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
def unregister(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "*")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
r.addHeader("expires", "0")
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
def testRegister(self):
|
||||
self.register()
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
|
||||
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
|
||||
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
|
||||
self.failUnless(
|
||||
int(m.headers["expires"][0]) in (3600, 3601, 3599, 3598))
|
||||
self.assertEqual(len(self.registry.users), 1)
|
||||
dc, uri = self.registry.users["joe"]
|
||||
self.assertEqual(uri.toString(), "sip:joe@client.com:5060")
|
||||
d = self.proxy.locator.getAddress(sip.URL(username="joe",
|
||||
host="bell.example.com"))
|
||||
d.addCallback(lambda desturl : (desturl.host, desturl.port))
|
||||
d.addCallback(self.assertEqual, ('client.com', 5060))
|
||||
return d
|
||||
|
||||
def testUnregister(self):
|
||||
self.register()
|
||||
self.unregister()
|
||||
dest, m = self.sent[1]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
|
||||
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
|
||||
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
|
||||
self.assertEqual(m.headers["expires"], ["0"])
|
||||
self.assertEqual(self.registry.users, {})
|
||||
|
||||
|
||||
def addPortal(self):
|
||||
r = TestRealm()
|
||||
p = cred.portal.Portal(r)
|
||||
c = cred.checkers.InMemoryUsernamePasswordDatabaseDontUse()
|
||||
c.addUser('userXname@127.0.0.1', 'passXword')
|
||||
p.registerChecker(c)
|
||||
self.proxy.portal = p
|
||||
|
||||
def testFailedAuthentication(self):
|
||||
self.addPortal()
|
||||
self.register()
|
||||
|
||||
self.assertEqual(len(self.registry.users), 0)
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(m.code, 401)
|
||||
|
||||
|
||||
def test_basicAuthentication(self):
|
||||
"""
|
||||
Test that registration with basic authentication suceeds.
|
||||
"""
|
||||
self.addPortal()
|
||||
self.proxy.authorizers = self.proxy.authorizers.copy()
|
||||
self.proxy.authorizers['basic'] = sip.BasicAuthorizer()
|
||||
warnings = self.flushWarnings(
|
||||
offendingFunctions=[self.test_basicAuthentication])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]['message'],
|
||||
"twisted.protocols.sip.BasicAuthorizer was deprecated in "
|
||||
"Twisted 9.0.0")
|
||||
self.assertEqual(
|
||||
warnings[0]['category'],
|
||||
DeprecationWarning)
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
r.addHeader("authorization",
|
||||
"Basic " + "userXname:passXword".encode('base64'))
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
self.assertEqual(len(self.registry.users), 1)
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(m.code, 200)
|
||||
|
||||
|
||||
def test_failedBasicAuthentication(self):
|
||||
"""
|
||||
Failed registration with basic authentication results in an
|
||||
unauthorized error response.
|
||||
"""
|
||||
self.addPortal()
|
||||
self.proxy.authorizers = self.proxy.authorizers.copy()
|
||||
self.proxy.authorizers['basic'] = sip.BasicAuthorizer()
|
||||
warnings = self.flushWarnings(
|
||||
offendingFunctions=[self.test_failedBasicAuthentication])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]['message'],
|
||||
"twisted.protocols.sip.BasicAuthorizer was deprecated in "
|
||||
"Twisted 9.0.0")
|
||||
self.assertEqual(
|
||||
warnings[0]['category'],
|
||||
DeprecationWarning)
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
r.addHeader(
|
||||
"authorization", "Basic " + "userXname:password".encode('base64'))
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
self.assertEqual(len(self.registry.users), 0)
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(m.code, 401)
|
||||
|
||||
|
||||
def testWrongDomainRegister(self):
|
||||
r = sip.Request("REGISTER", "sip:wrong.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testWrongToDomainRegister(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@foo.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testWrongDomainLookup(self):
|
||||
self.register()
|
||||
url = sip.URL(username="joe", host="foo.com")
|
||||
d = self.proxy.locator.getAddress(url)
|
||||
self.assertFailure(d, LookupError)
|
||||
return d
|
||||
|
||||
def testNoContactLookup(self):
|
||||
self.register()
|
||||
url = sip.URL(username="jane", host="bell.example.com")
|
||||
d = self.proxy.locator.getAddress(url)
|
||||
self.assertFailure(d, LookupError)
|
||||
return d
|
||||
|
||||
|
||||
class Client(sip.Base):
|
||||
|
||||
def __init__(self):
|
||||
sip.Base.__init__(self)
|
||||
self.received = []
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def handle_response(self, response, addr):
|
||||
self.received.append(response)
|
||||
self.deferred.callback(self.received)
|
||||
|
||||
|
||||
class LiveTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.proxy = sip.RegisterProxy(host="127.0.0.1")
|
||||
self.registry = sip.InMemoryRegistry("bell.example.com")
|
||||
self.proxy.registry = self.proxy.locator = self.registry
|
||||
self.serverPort = reactor.listenUDP(
|
||||
0, self.proxy, interface="127.0.0.1")
|
||||
self.client = Client()
|
||||
self.clientPort = reactor.listenUDP(
|
||||
0, self.client, interface="127.0.0.1")
|
||||
self.serverAddress = (self.serverPort.getHost().host,
|
||||
self.serverPort.getHost().port)
|
||||
setUp = utils.suppressWarnings(setUp,
|
||||
util.suppress(category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
|
||||
|
||||
def tearDown(self):
|
||||
for d, uri in self.registry.users.values():
|
||||
d.cancel()
|
||||
d1 = defer.maybeDeferred(self.clientPort.stopListening)
|
||||
d2 = defer.maybeDeferred(self.serverPort.stopListening)
|
||||
return defer.gatherResults([d1, d2])
|
||||
|
||||
def testRegister(self):
|
||||
p = self.clientPort.getHost().port
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
|
||||
r.addHeader("via", sip.Via("127.0.0.1", port=p).toString())
|
||||
self.client.sendMessage(
|
||||
sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r)
|
||||
d = self.client.deferred
|
||||
def check(received):
|
||||
self.assertEqual(len(received), 1)
|
||||
r = received[0]
|
||||
self.assertEqual(r.code, 200)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_amoralRPort(self):
|
||||
"""
|
||||
rport is allowed without a value, apparently because server
|
||||
implementors might be too stupid to check the received port
|
||||
against 5060 and see if they're equal, and because client
|
||||
implementors might be too stupid to bind to port 5060, or set a
|
||||
value on the rport parameter they send if they bind to another
|
||||
port.
|
||||
"""
|
||||
p = self.clientPort.getHost().port
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
|
||||
r.addHeader("via", sip.Via("127.0.0.1", port=p, rport=True).toString())
|
||||
warnings = self.flushWarnings(
|
||||
offendingFunctions=[self.test_amoralRPort])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]['message'],
|
||||
'rport=True is deprecated since Twisted 9.0.')
|
||||
self.assertEqual(
|
||||
warnings[0]['category'],
|
||||
DeprecationWarning)
|
||||
self.client.sendMessage(sip.URL(host="127.0.0.1",
|
||||
port=self.serverAddress[1]),
|
||||
r)
|
||||
d = self.client.deferred
|
||||
def check(received):
|
||||
self.assertEqual(len(received), 1)
|
||||
r = received[0]
|
||||
self.assertEqual(r.code, 200)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
registerRequest = """
|
||||
REGISTER sip:intarweb.us SIP/2.0\r
|
||||
Via: SIP/2.0/UDP 192.168.1.100:50609\r
|
||||
From: <sip:exarkun@intarweb.us:50609>\r
|
||||
To: <sip:exarkun@intarweb.us:50609>\r
|
||||
Contact: "exarkun" <sip:exarkun@192.168.1.100:50609>\r
|
||||
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
|
||||
CSeq: 9898 REGISTER\r
|
||||
Expires: 500\r
|
||||
User-Agent: X-Lite build 1061\r
|
||||
Content-Length: 0\r
|
||||
\r
|
||||
"""
|
||||
|
||||
challengeResponse = """\
|
||||
SIP/2.0 401 Unauthorized\r
|
||||
Via: SIP/2.0/UDP 192.168.1.100:50609;received=127.0.0.1;rport=5632\r
|
||||
To: <sip:exarkun@intarweb.us:50609>\r
|
||||
From: <sip:exarkun@intarweb.us:50609>\r
|
||||
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
|
||||
CSeq: 9898 REGISTER\r
|
||||
WWW-Authenticate: Digest nonce="92956076410767313901322208775",opaque="1674186428",qop-options="auth",algorithm="MD5",realm="intarweb.us"\r
|
||||
\r
|
||||
"""
|
||||
|
||||
authRequest = """\
|
||||
REGISTER sip:intarweb.us SIP/2.0\r
|
||||
Via: SIP/2.0/UDP 192.168.1.100:50609\r
|
||||
From: <sip:exarkun@intarweb.us:50609>\r
|
||||
To: <sip:exarkun@intarweb.us:50609>\r
|
||||
Contact: "exarkun" <sip:exarkun@192.168.1.100:50609>\r
|
||||
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
|
||||
CSeq: 9899 REGISTER\r
|
||||
Expires: 500\r
|
||||
Authorization: Digest username="exarkun",realm="intarweb.us",nonce="92956076410767313901322208775",response="4a47980eea31694f997369214292374b",uri="sip:intarweb.us",algorithm=MD5,opaque="1674186428"\r
|
||||
User-Agent: X-Lite build 1061\r
|
||||
Content-Length: 0\r
|
||||
\r
|
||||
"""
|
||||
|
||||
okResponse = """\
|
||||
SIP/2.0 200 OK\r
|
||||
Via: SIP/2.0/UDP 192.168.1.100:50609;received=127.0.0.1;rport=5632\r
|
||||
To: <sip:exarkun@intarweb.us:50609>\r
|
||||
From: <sip:exarkun@intarweb.us:50609>\r
|
||||
Call-ID: 94E7E5DAF39111D791C6000393764646@intarweb.us\r
|
||||
CSeq: 9899 REGISTER\r
|
||||
Contact: sip:exarkun@127.0.0.1:5632\r
|
||||
Expires: 3600\r
|
||||
Content-Length: 0\r
|
||||
\r
|
||||
"""
|
||||
|
||||
class FakeDigestAuthorizer(sip.DigestAuthorizer):
|
||||
def generateNonce(self):
|
||||
return '92956076410767313901322208775'
|
||||
def generateOpaque(self):
|
||||
return '1674186428'
|
||||
|
||||
|
||||
class FakeRegistry(sip.InMemoryRegistry):
|
||||
"""Make sure expiration is always seen to be 3600.
|
||||
|
||||
Otherwise slow reactors fail tests incorrectly.
|
||||
"""
|
||||
|
||||
def _cbReg(self, reg):
|
||||
if 3600 < reg.secondsToExpiry or reg.secondsToExpiry < 3598:
|
||||
raise RuntimeError(
|
||||
"bad seconds to expire: %s" % reg.secondsToExpiry)
|
||||
reg.secondsToExpiry = 3600
|
||||
return reg
|
||||
|
||||
def getRegistrationInfo(self, uri):
|
||||
d = sip.InMemoryRegistry.getRegistrationInfo(self, uri)
|
||||
return d.addCallback(self._cbReg)
|
||||
|
||||
def registerAddress(self, domainURL, logicalURL, physicalURL):
|
||||
d = sip.InMemoryRegistry.registerAddress(
|
||||
self, domainURL, logicalURL, physicalURL)
|
||||
return d.addCallback(self._cbReg)
|
||||
|
||||
class AuthorizationTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.proxy = sip.RegisterProxy(host="intarweb.us")
|
||||
self.proxy.authorizers = self.proxy.authorizers.copy()
|
||||
self.proxy.authorizers['digest'] = FakeDigestAuthorizer()
|
||||
|
||||
self.registry = FakeRegistry("intarweb.us")
|
||||
self.proxy.registry = self.proxy.locator = self.registry
|
||||
self.transport = proto_helpers.FakeDatagramTransport()
|
||||
self.proxy.transport = self.transport
|
||||
|
||||
r = TestRealm()
|
||||
p = cred.portal.Portal(r)
|
||||
c = cred.checkers.InMemoryUsernamePasswordDatabaseDontUse()
|
||||
c.addUser('exarkun@intarweb.us', 'password')
|
||||
p.registerChecker(c)
|
||||
self.proxy.portal = p
|
||||
setUp = utils.suppressWarnings(setUp,
|
||||
util.suppress(category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'))
|
||||
|
||||
def tearDown(self):
|
||||
for d, uri in self.registry.users.values():
|
||||
d.cancel()
|
||||
del self.proxy
|
||||
|
||||
def testChallenge(self):
|
||||
self.proxy.datagramReceived(registerRequest, ("127.0.0.1", 5632))
|
||||
|
||||
self.assertEqual(
|
||||
self.transport.written[-1],
|
||||
((challengeResponse, ("127.0.0.1", 5632)))
|
||||
)
|
||||
self.transport.written = []
|
||||
|
||||
self.proxy.datagramReceived(authRequest, ("127.0.0.1", 5632))
|
||||
|
||||
self.assertEqual(
|
||||
self.transport.written[-1],
|
||||
((okResponse, ("127.0.0.1", 5632)))
|
||||
)
|
||||
testChallenge.suppress = [
|
||||
util.suppress(
|
||||
category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestAuthorizer was deprecated'),
|
||||
util.suppress(
|
||||
category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestedCredentials was deprecated'),
|
||||
util.suppress(
|
||||
category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestCalcHA1 was deprecated'),
|
||||
util.suppress(
|
||||
category=DeprecationWarning,
|
||||
message=r'twisted.protocols.sip.DigestCalcResponse was deprecated')]
|
||||
|
||||
|
||||
|
||||
class DeprecationTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for deprecation of obsolete components of L{twisted.protocols.sip}.
|
||||
"""
|
||||
|
||||
def test_deprecatedDigestCalcHA1(self):
|
||||
"""
|
||||
L{sip.DigestCalcHA1} is deprecated.
|
||||
"""
|
||||
self.callDeprecated(Version("Twisted", 9, 0, 0),
|
||||
sip.DigestCalcHA1, '', '', '', '', '', '')
|
||||
|
||||
|
||||
def test_deprecatedDigestCalcResponse(self):
|
||||
"""
|
||||
L{sip.DigestCalcResponse} is deprecated.
|
||||
"""
|
||||
self.callDeprecated(Version("Twisted", 9, 0, 0),
|
||||
sip.DigestCalcResponse, '', '', '', '', '', '', '',
|
||||
'')
|
||||
|
||||
def test_deprecatedBasicAuthorizer(self):
|
||||
"""
|
||||
L{sip.BasicAuthorizer} is deprecated.
|
||||
"""
|
||||
self.callDeprecated(Version("Twisted", 9, 0, 0), sip.BasicAuthorizer)
|
||||
|
||||
|
||||
def test_deprecatedDigestAuthorizer(self):
|
||||
"""
|
||||
L{sip.DigestAuthorizer} is deprecated.
|
||||
"""
|
||||
self.callDeprecated(Version("Twisted", 9, 0, 0), sip.DigestAuthorizer)
|
||||
|
||||
|
||||
def test_deprecatedDigestedCredentials(self):
|
||||
"""
|
||||
L{sip.DigestedCredentials} is deprecated.
|
||||
"""
|
||||
self.callDeprecated(Version("Twisted", 9, 0, 0),
|
||||
sip.DigestedCredentials, '', {}, {})
|
||||
172
Linux_i686/lib/python2.7/site-packages/twisted/test/test_sob.py
Normal file
172
Linux_i686/lib/python2.7/site-packages/twisted/test/test_sob.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
import sys, os
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.AES
|
||||
except ImportError:
|
||||
Crypto = None
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.persisted import sob
|
||||
from twisted.python import components
|
||||
|
||||
class Dummy(components.Componentized):
|
||||
pass
|
||||
|
||||
objects = [
|
||||
1,
|
||||
"hello",
|
||||
(1, "hello"),
|
||||
[1, "hello"],
|
||||
{1:"hello"},
|
||||
]
|
||||
|
||||
class FakeModule(object):
|
||||
pass
|
||||
|
||||
class PersistTestCase(unittest.TestCase):
|
||||
def testStyles(self):
|
||||
for o in objects:
|
||||
p = sob.Persistent(o, '')
|
||||
for style in 'source pickle'.split():
|
||||
p.setStyle(style)
|
||||
p.save(filename='persisttest.'+style)
|
||||
o1 = sob.load('persisttest.'+style, style)
|
||||
self.assertEqual(o, o1)
|
||||
|
||||
def testStylesBeingSet(self):
|
||||
o = Dummy()
|
||||
o.foo = 5
|
||||
o.setComponent(sob.IPersistable, sob.Persistent(o, 'lala'))
|
||||
for style in 'source pickle'.split():
|
||||
sob.IPersistable(o).setStyle(style)
|
||||
sob.IPersistable(o).save(filename='lala.'+style)
|
||||
o1 = sob.load('lala.'+style, style)
|
||||
self.assertEqual(o.foo, o1.foo)
|
||||
self.assertEqual(sob.IPersistable(o1).style, style)
|
||||
|
||||
|
||||
def testNames(self):
|
||||
o = [1,2,3]
|
||||
p = sob.Persistent(o, 'object')
|
||||
for style in 'source pickle'.split():
|
||||
p.setStyle(style)
|
||||
p.save()
|
||||
o1 = sob.load('object.ta'+style[0], style)
|
||||
self.assertEqual(o, o1)
|
||||
for tag in 'lala lolo'.split():
|
||||
p.save(tag)
|
||||
o1 = sob.load('object-'+tag+'.ta'+style[0], style)
|
||||
self.assertEqual(o, o1)
|
||||
|
||||
def testEncryptedStyles(self):
|
||||
for o in objects:
|
||||
phrase='once I was the king of spain'
|
||||
p = sob.Persistent(o, '')
|
||||
for style in 'source pickle'.split():
|
||||
p.setStyle(style)
|
||||
p.save(filename='epersisttest.'+style, passphrase=phrase)
|
||||
o1 = sob.load('epersisttest.'+style, style, phrase)
|
||||
self.assertEqual(o, o1)
|
||||
if Crypto is None:
|
||||
testEncryptedStyles.skip = "PyCrypto required for encrypted config"
|
||||
|
||||
def testPython(self):
|
||||
f = open("persisttest.python", 'w')
|
||||
f.write('foo=[1,2,3] ')
|
||||
f.close()
|
||||
o = sob.loadValueFromFile('persisttest.python', 'foo')
|
||||
self.assertEqual(o, [1,2,3])
|
||||
|
||||
def testEncryptedPython(self):
|
||||
phrase='once I was the king of spain'
|
||||
f = open("epersisttest.python", 'w')
|
||||
f.write(
|
||||
sob._encrypt(phrase, 'foo=[1,2,3]'))
|
||||
f.close()
|
||||
o = sob.loadValueFromFile('epersisttest.python', 'foo', phrase)
|
||||
self.assertEqual(o, [1,2,3])
|
||||
if Crypto is None:
|
||||
testEncryptedPython.skip = "PyCrypto required for encrypted config"
|
||||
|
||||
def testTypeGuesser(self):
|
||||
self.assertRaises(KeyError, sob.guessType, "file.blah")
|
||||
self.assertEqual('python', sob.guessType("file.py"))
|
||||
self.assertEqual('python', sob.guessType("file.tac"))
|
||||
self.assertEqual('python', sob.guessType("file.etac"))
|
||||
self.assertEqual('pickle', sob.guessType("file.tap"))
|
||||
self.assertEqual('pickle', sob.guessType("file.etap"))
|
||||
self.assertEqual('source', sob.guessType("file.tas"))
|
||||
self.assertEqual('source', sob.guessType("file.etas"))
|
||||
|
||||
def testEverythingEphemeralGetattr(self):
|
||||
"""
|
||||
Verify that _EverythingEphermal.__getattr__ works.
|
||||
"""
|
||||
self.fakeMain.testMainModGetattr = 1
|
||||
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
|
||||
filename = os.path.join(dirname, 'persisttest.ee_getattr')
|
||||
|
||||
f = file(filename, 'w')
|
||||
f.write('import __main__\n')
|
||||
f.write('if __main__.testMainModGetattr != 1: raise AssertionError\n')
|
||||
f.write('app = None\n')
|
||||
f.close()
|
||||
|
||||
sob.load(filename, 'source')
|
||||
|
||||
def testEverythingEphemeralSetattr(self):
|
||||
"""
|
||||
Verify that _EverythingEphemeral.__setattr__ won't affect __main__.
|
||||
"""
|
||||
self.fakeMain.testMainModSetattr = 1
|
||||
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
|
||||
filename = os.path.join(dirname, 'persisttest.ee_setattr')
|
||||
f = file(filename, 'w')
|
||||
f.write('import __main__\n')
|
||||
f.write('__main__.testMainModSetattr = 2\n')
|
||||
f.write('app = None\n')
|
||||
f.close()
|
||||
|
||||
sob.load(filename, 'source')
|
||||
|
||||
self.assertEqual(self.fakeMain.testMainModSetattr, 1)
|
||||
|
||||
def testEverythingEphemeralException(self):
|
||||
"""
|
||||
Test that an exception during load() won't cause _EE to mask __main__
|
||||
"""
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
filename = os.path.join(dirname, 'persisttest.ee_exception')
|
||||
|
||||
f = file(filename, 'w')
|
||||
f.write('raise ValueError\n')
|
||||
f.close()
|
||||
|
||||
self.assertRaises(ValueError, sob.load, filename, 'source')
|
||||
self.assertEqual(type(sys.modules['__main__']), FakeModule)
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Replace the __main__ module with a fake one, so that it can be mutated
|
||||
in tests
|
||||
"""
|
||||
self.realMain = sys.modules['__main__']
|
||||
self.fakeMain = sys.modules['__main__'] = FakeModule()
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore __main__ to its original value
|
||||
"""
|
||||
sys.modules['__main__'] = self.realMain
|
||||
|
||||
|
|
@ -0,0 +1,498 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocol.socks}, an implementation of the SOCKSv4 and
|
||||
SOCKSv4a protocols.
|
||||
"""
|
||||
|
||||
import struct, socket
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.test import proto_helpers
|
||||
from twisted.internet import defer, address, reactor
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.protocols import socks
|
||||
|
||||
|
||||
class StringTCPTransport(proto_helpers.StringTransport):
|
||||
stringTCPTransport_closing = False
|
||||
peer = None
|
||||
|
||||
def getPeer(self):
|
||||
return self.peer
|
||||
|
||||
def getHost(self):
|
||||
return address.IPv4Address('TCP', '2.3.4.5', 42)
|
||||
|
||||
def loseConnection(self):
|
||||
self.stringTCPTransport_closing = True
|
||||
|
||||
|
||||
|
||||
class FakeResolverReactor:
|
||||
"""
|
||||
Bare-bones reactor with deterministic behavior for the resolve method.
|
||||
"""
|
||||
def __init__(self, names):
|
||||
"""
|
||||
@type names: C{dict} containing C{str} keys and C{str} values.
|
||||
@param names: A hostname to IP address mapping. The IP addresses are
|
||||
stringified dotted quads.
|
||||
"""
|
||||
self.names = names
|
||||
|
||||
|
||||
def resolve(self, hostname):
|
||||
"""
|
||||
Resolve a hostname by looking it up in the C{names} dictionary.
|
||||
"""
|
||||
try:
|
||||
return defer.succeed(self.names[hostname])
|
||||
except KeyError:
|
||||
return defer.fail(
|
||||
DNSLookupError("FakeResolverReactor couldn't find " + hostname))
|
||||
|
||||
|
||||
|
||||
class SOCKSv4Driver(socks.SOCKSv4):
|
||||
# last SOCKSv4Outgoing instantiated
|
||||
driver_outgoing = None
|
||||
|
||||
# last SOCKSv4IncomingFactory instantiated
|
||||
driver_listen = None
|
||||
|
||||
def connectClass(self, host, port, klass, *args):
|
||||
# fake it
|
||||
proto = klass(*args)
|
||||
proto.transport = StringTCPTransport()
|
||||
proto.transport.peer = address.IPv4Address('TCP', host, port)
|
||||
proto.connectionMade()
|
||||
self.driver_outgoing = proto
|
||||
return defer.succeed(proto)
|
||||
|
||||
def listenClass(self, port, klass, *args):
|
||||
# fake it
|
||||
factory = klass(*args)
|
||||
self.driver_listen = factory
|
||||
if port == 0:
|
||||
port = 1234
|
||||
return defer.succeed(('6.7.8.9', port))
|
||||
|
||||
|
||||
|
||||
class Connect(unittest.TestCase):
|
||||
"""
|
||||
Tests for SOCKS and SOCKSv4a connect requests using the L{SOCKSv4} protocol.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.sock = SOCKSv4Driver()
|
||||
self.sock.transport = StringTCPTransport()
|
||||
self.sock.connectionMade()
|
||||
self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"})
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
outgoing = self.sock.driver_outgoing
|
||||
if outgoing is not None:
|
||||
self.assert_(outgoing.transport.stringTCPTransport_closing,
|
||||
"Outgoing SOCKS connections need to be closed.")
|
||||
|
||||
|
||||
def test_simple(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 1, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 34)
|
||||
+ socket.inet_aton('1.2.3.4'))
|
||||
self.assert_(not self.sock.transport.stringTCPTransport_closing)
|
||||
self.assert_(self.sock.driver_outgoing is not None)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# the other way around
|
||||
self.sock.driver_outgoing.dataReceived('hi there')
|
||||
self.assertEqual(self.sock.transport.value(), 'hi there')
|
||||
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
|
||||
def test_socks4aSuccessfulResolution(self):
|
||||
"""
|
||||
If the destination IP address has zeros for the first three octets and
|
||||
non-zero for the fourth octet, the client is attempting a v4a
|
||||
connection. A hostname is specified after the user ID string and the
|
||||
server connects to the address that hostname resolves to.
|
||||
|
||||
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
|
||||
"""
|
||||
# send the domain name "localhost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack('!BBH', 4, 1, 34)
|
||||
+ socket.inet_aton('0.0.0.1')
|
||||
+ 'fooBAZ\0'
|
||||
+ 'localhost\0')
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in clientRequest:
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# Verify that the server responded with the address which will be
|
||||
# connected to.
|
||||
self.assertEqual(
|
||||
sent,
|
||||
struct.pack('!BBH', 0, 90, 34) + socket.inet_aton('127.0.0.1'))
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertNotIdentical(self.sock.driver_outgoing, None)
|
||||
|
||||
# Pass some data through and verify it is forwarded to the outgoing
|
||||
# connection.
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(
|
||||
self.sock.driver_outgoing.transport.value(), 'hello, world')
|
||||
|
||||
# Deliver some data from the output connection and verify it is
|
||||
# passed along to the incoming side.
|
||||
self.sock.driver_outgoing.dataReceived('hi there')
|
||||
self.assertEqual(self.sock.transport.value(), 'hi there')
|
||||
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
|
||||
def test_socks4aFailedResolution(self):
|
||||
"""
|
||||
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
|
||||
response and the connection getting closed.
|
||||
"""
|
||||
# send the domain name "failinghost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack('!BBH', 4, 1, 34)
|
||||
+ socket.inet_aton('0.0.0.1')
|
||||
+ 'fooBAZ\0'
|
||||
+ 'failinghost\0')
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in clientRequest:
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
# Verify that the server responds with a 91 error.
|
||||
sent = self.sock.transport.value()
|
||||
self.assertEqual(
|
||||
sent,
|
||||
struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0'))
|
||||
|
||||
# A failed resolution causes the transport to drop the connection.
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIdentical(self.sock.driver_outgoing, None)
|
||||
|
||||
|
||||
def test_accessDenied(self):
|
||||
self.sock.authorize = lambda code, server, port, user: 0
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 1, 4242)
|
||||
+ socket.inet_aton('10.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
self.assertEqual(self.sock.transport.value(),
|
||||
struct.pack('!BBH', 0, 91, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIdentical(self.sock.driver_outgoing, None)
|
||||
|
||||
|
||||
def test_eofRemote(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 1, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# now close it from the server side
|
||||
self.sock.driver_outgoing.transport.loseConnection()
|
||||
self.sock.driver_outgoing.connectionLost('fake reason')
|
||||
|
||||
|
||||
def test_eofLocal(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 1, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# now close it from the client side
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
|
||||
|
||||
class Bind(unittest.TestCase):
|
||||
"""
|
||||
Tests for SOCKS and SOCKSv4a bind requests using the L{SOCKSv4} protocol.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.sock = SOCKSv4Driver()
|
||||
self.sock.transport = StringTCPTransport()
|
||||
self.sock.connectionMade()
|
||||
self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"})
|
||||
|
||||
## def tearDown(self):
|
||||
## # TODO ensure the listen port is closed
|
||||
## listen = self.sock.driver_listen
|
||||
## if listen is not None:
|
||||
## self.assert_(incoming.transport.stringTCPTransport_closing,
|
||||
## "Incoming SOCKS connections need to be closed.")
|
||||
|
||||
def test_simple(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 1234)
|
||||
+ socket.inet_aton('6.7.8.9'))
|
||||
self.assert_(not self.sock.transport.stringTCPTransport_closing)
|
||||
self.assert_(self.sock.driver_listen is not None)
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
|
||||
self.assertNotIdentical(incoming, None)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(not self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(incoming.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# the other way around
|
||||
incoming.dataReceived('hi there')
|
||||
self.assertEqual(self.sock.transport.value(), 'hi there')
|
||||
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
|
||||
def test_socks4a(self):
|
||||
"""
|
||||
If the destination IP address has zeros for the first three octets and
|
||||
non-zero for the fourth octet, the client is attempting a v4a
|
||||
connection. A hostname is specified after the user ID string and the
|
||||
server connects to the address that hostname resolves to.
|
||||
|
||||
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
|
||||
"""
|
||||
# send the domain name "localhost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('0.0.0.1')
|
||||
+ 'fooBAZ\0'
|
||||
+ 'localhost\0')
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in clientRequest:
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# Verify that the server responded with the address which will be
|
||||
# connected to.
|
||||
self.assertEqual(
|
||||
sent,
|
||||
struct.pack('!BBH', 0, 90, 1234) + socket.inet_aton('6.7.8.9'))
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertNotIdentical(self.sock.driver_listen, None)
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(('127.0.0.1', 5345))
|
||||
self.assertNotIdentical(incoming, None)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assertNotIdentical(
|
||||
self.sock.transport.stringTCPTransport_closing, None)
|
||||
|
||||
# Deliver some data from the output connection and verify it is
|
||||
# passed along to the incoming side.
|
||||
self.sock.dataReceived('hi there')
|
||||
self.assertEqual(incoming.transport.value(), 'hi there')
|
||||
|
||||
# the other way around
|
||||
incoming.dataReceived('hi there')
|
||||
self.assertEqual(self.sock.transport.value(), 'hi there')
|
||||
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
|
||||
def test_socks4aFailedResolution(self):
|
||||
"""
|
||||
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
|
||||
response and the connection getting closed.
|
||||
"""
|
||||
# send the domain name "failinghost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('0.0.0.1')
|
||||
+ 'fooBAZ\0'
|
||||
+ 'failinghost\0')
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in clientRequest:
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
# Verify that the server responds with a 91 error.
|
||||
sent = self.sock.transport.value()
|
||||
self.assertEqual(
|
||||
sent,
|
||||
struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0'))
|
||||
|
||||
# A failed resolution causes the transport to drop the connection.
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIdentical(self.sock.driver_outgoing, None)
|
||||
|
||||
|
||||
def test_accessDenied(self):
|
||||
self.sock.authorize = lambda code, server, port, user: 0
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 2, 4242)
|
||||
+ socket.inet_aton('10.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
self.assertEqual(self.sock.transport.value(),
|
||||
struct.pack('!BBH', 0, 91, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIdentical(self.sock.driver_listen, None)
|
||||
|
||||
def test_eofRemote(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
|
||||
self.assertNotIdentical(incoming, None)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(not self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(incoming.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# now close it from the server side
|
||||
incoming.transport.loseConnection()
|
||||
incoming.connectionLost('fake reason')
|
||||
|
||||
def test_eofLocal(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(('1.2.3.4', 5345))
|
||||
self.assertNotIdentical(incoming, None)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 90, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(not self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived('hello, world')
|
||||
self.assertEqual(incoming.transport.value(),
|
||||
'hello, world')
|
||||
|
||||
# now close it from the client side
|
||||
self.sock.connectionLost('fake reason')
|
||||
|
||||
def test_badSource(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack('!BBH', 4, 2, 34)
|
||||
+ socket.inet_aton('1.2.3.4')
|
||||
+ 'fooBAR'
|
||||
+ '\0')
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect from WRONG address
|
||||
incoming = self.sock.driver_listen.buildProtocol(('1.6.6.6', 666))
|
||||
self.assertIdentical(incoming, None)
|
||||
|
||||
# Now we should have the second reply packet and it should
|
||||
# be a failure. The connection should be closing.
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(sent,
|
||||
struct.pack('!BBH', 0, 91, 0)
|
||||
+ socket.inet_aton('0.0.0.0'))
|
||||
self.assert_(self.sock.transport.stringTCPTransport_closing)
|
||||
727
Linux_i686/lib/python2.7/site-packages/twisted/test/test_ssl.py
Normal file
727
Linux_i686/lib/python2.7/site-packages/twisted/test/test_ssl.py
Normal file
|
|
@ -0,0 +1,727 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for twisted SSL support.
|
||||
"""
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import protocol, reactor, interfaces, defer
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.protocols import basic
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.test.test_tcp import ProperlyCloseFilesMixin
|
||||
|
||||
import os, errno
|
||||
|
||||
try:
|
||||
from OpenSSL import SSL, crypto
|
||||
from twisted.internet import ssl
|
||||
from twisted.test.ssl_helpers import ClientTLSContext, certPath
|
||||
except ImportError:
|
||||
def _noSSL():
|
||||
# ugh, make pyflakes happy.
|
||||
global SSL
|
||||
global ssl
|
||||
SSL = ssl = None
|
||||
_noSSL()
|
||||
|
||||
try:
|
||||
from twisted.protocols import tls as newTLS
|
||||
except ImportError:
|
||||
# Assuming SSL exists, we're using old version in reactor (i.e. non-protocol)
|
||||
newTLS = None
|
||||
|
||||
|
||||
class UnintelligentProtocol(basic.LineReceiver):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire at connection lost.
|
||||
@type deferred: L{defer.Deferred}
|
||||
|
||||
@cvar pretext: text sent before TLS is set up.
|
||||
@type pretext: C{bytes}
|
||||
|
||||
@cvar posttext: text sent after TLS is set up.
|
||||
@type posttext: C{bytes}
|
||||
"""
|
||||
pretext = [
|
||||
b"first line",
|
||||
b"last thing before tls starts",
|
||||
b"STARTTLS"]
|
||||
|
||||
posttext = [
|
||||
b"first thing after tls started",
|
||||
b"last thing ever"]
|
||||
|
||||
def __init__(self):
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
for l in self.pretext:
|
||||
self.sendLine(l)
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
if line == b"READY":
|
||||
self.transport.startTLS(ClientTLSContext(), self.factory.client)
|
||||
for l in self.posttext:
|
||||
self.sendLine(l)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
|
||||
class LineCollector(basic.LineReceiver):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire at connection lost.
|
||||
@type deferred: L{defer.Deferred}
|
||||
|
||||
@ivar doTLS: whether the protocol is initiate TLS or not.
|
||||
@type doTLS: C{bool}
|
||||
|
||||
@ivar fillBuffer: if set to True, it will send lots of data once
|
||||
C{STARTTLS} is received.
|
||||
@type fillBuffer: C{bool}
|
||||
"""
|
||||
|
||||
def __init__(self, doTLS, fillBuffer=False):
|
||||
self.doTLS = doTLS
|
||||
self.fillBuffer = fillBuffer
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.factory.rawdata = b''
|
||||
self.factory.lines = []
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.factory.lines.append(line)
|
||||
if line == b'STARTTLS':
|
||||
if self.fillBuffer:
|
||||
for x in range(500):
|
||||
self.sendLine(b'X' * 1000)
|
||||
self.sendLine(b'READY')
|
||||
if self.doTLS:
|
||||
ctx = ServerTLSContext(
|
||||
privateKeyFileName=certPath,
|
||||
certificateFileName=certPath,
|
||||
)
|
||||
self.transport.startTLS(ctx, self.factory.server)
|
||||
else:
|
||||
self.setRawMode()
|
||||
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
self.factory.rawdata += data
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
|
||||
class SingleLineServerProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol that sends a single line of data at C{connectionMade}.
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(b"+OK <some crap>\r\n")
|
||||
self.transport.getPeerCertificate()
|
||||
|
||||
|
||||
|
||||
class RecordingClientProtocol(protocol.Protocol):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire with first received content.
|
||||
@type deferred: L{defer.Deferred}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.getPeerCertificate()
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.deferred.callback(data)
|
||||
|
||||
|
||||
|
||||
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol that disconnect immediately on connection. It fires the
|
||||
C{connectionDisconnected} deferred of its factory on connetion lost.
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.factory.connectionDisconnected.callback(None)
|
||||
|
||||
|
||||
|
||||
def generateCertificateObjects(organization, organizationalUnit):
|
||||
"""
|
||||
Create a certificate for given C{organization} and C{organizationalUnit}.
|
||||
|
||||
@return: a tuple of (key, request, certificate) objects.
|
||||
"""
|
||||
pkey = crypto.PKey()
|
||||
pkey.generate_key(crypto.TYPE_RSA, 512)
|
||||
req = crypto.X509Req()
|
||||
subject = req.get_subject()
|
||||
subject.O = organization
|
||||
subject.OU = organizationalUnit
|
||||
req.set_pubkey(pkey)
|
||||
req.sign(pkey, "md5")
|
||||
|
||||
# Here comes the actual certificate
|
||||
cert = crypto.X509()
|
||||
cert.set_serial_number(1)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
|
||||
cert.set_issuer(req.get_subject())
|
||||
cert.set_subject(req.get_subject())
|
||||
cert.set_pubkey(req.get_pubkey())
|
||||
cert.sign(pkey, "md5")
|
||||
|
||||
return pkey, req, cert
|
||||
|
||||
|
||||
|
||||
def generateCertificateFiles(basename, organization, organizationalUnit):
|
||||
"""
|
||||
Create certificate files key, req and cert prefixed by C{basename} for
|
||||
given C{organization} and C{organizationalUnit}.
|
||||
"""
|
||||
pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
|
||||
|
||||
for ext, obj, dumpFunc in [
|
||||
('key', pkey, crypto.dump_privatekey),
|
||||
('req', req, crypto.dump_certificate_request),
|
||||
('cert', cert, crypto.dump_certificate)]:
|
||||
fName = os.extsep.join((basename, ext)).encode("utf-8")
|
||||
FilePath(fName).setContent(dumpFunc(crypto.FILETYPE_PEM, obj))
|
||||
|
||||
|
||||
|
||||
class ContextGeneratingMixin:
|
||||
"""
|
||||
Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
|
||||
and server.
|
||||
|
||||
@ivar clientBase: prefix of client certificate files.
|
||||
@type clientBase: C{str}
|
||||
|
||||
@ivar serverBase: prefix of server certificate files.
|
||||
@type serverBase: C{str}
|
||||
|
||||
@ivar clientCtxFactory: a generated context factory to be used in
|
||||
C{reactor.connectSSL}.
|
||||
@type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
|
||||
|
||||
@ivar serverCtxFactory: a generated context factory to be used in
|
||||
C{reactor.listenSSL}.
|
||||
@type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
|
||||
"""
|
||||
|
||||
def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
|
||||
base = self.mktemp()
|
||||
generateCertificateFiles(base, org, orgUnit)
|
||||
serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
|
||||
os.extsep.join((base, 'key')),
|
||||
os.extsep.join((base, 'cert')),
|
||||
*args, **kwArgs)
|
||||
|
||||
return base, serverCtxFactory
|
||||
|
||||
|
||||
def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
|
||||
serverKwArgs):
|
||||
self.clientBase, self.clientCtxFactory = self.makeContextFactory(
|
||||
*clientArgs, **clientKwArgs)
|
||||
self.serverBase, self.serverCtxFactory = self.makeContextFactory(
|
||||
*serverArgs, **serverKwArgs)
|
||||
|
||||
|
||||
|
||||
if SSL is not None:
|
||||
class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
|
||||
"""
|
||||
A context factory with a default method set to L{SSL.TLSv1_METHOD}.
|
||||
"""
|
||||
isClient = False
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
kw['sslmethod'] = SSL.TLSv1_METHOD
|
||||
ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
|
||||
|
||||
|
||||
|
||||
class StolenTCPTestCase(ProperlyCloseFilesMixin, unittest.TestCase):
|
||||
"""
|
||||
For SSL transports, test many of the same things which are tested for
|
||||
TCP transports.
|
||||
"""
|
||||
|
||||
def createServer(self, address, portNumber, factory):
|
||||
"""
|
||||
Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
|
||||
"""
|
||||
cert = ssl.PrivateCertificate.loadPEM(FilePath(certPath).getContent())
|
||||
contextFactory = cert.options()
|
||||
return reactor.listenSSL(
|
||||
portNumber, factory, contextFactory, interface=address)
|
||||
|
||||
|
||||
def connectClient(self, address, portNumber, clientCreator):
|
||||
"""
|
||||
Create an SSL client using L{IReactorSSL.connectSSL}.
|
||||
"""
|
||||
contextFactory = ssl.CertificateOptions()
|
||||
return clientCreator.connectSSL(address, portNumber, contextFactory)
|
||||
|
||||
|
||||
def getHandleExceptionType(self):
|
||||
"""
|
||||
Return L{SSL.Error} as the expected error type which will be raised by
|
||||
a write to the L{OpenSSL.SSL.Connection} object after it has been
|
||||
closed.
|
||||
"""
|
||||
return SSL.Error
|
||||
|
||||
|
||||
def getHandleErrorCode(self):
|
||||
"""
|
||||
Return the argument L{SSL.Error} will be constructed with for this
|
||||
case. This is basically just a random OpenSSL implementation detail.
|
||||
It would be better if this test worked in a way which did not require
|
||||
this.
|
||||
"""
|
||||
# Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for
|
||||
# SSL.Connection.write for some reason. The twisted.protocols.tls
|
||||
# implementation of IReactorSSL doesn't suffer from this imprecation,
|
||||
# though, since it is isolated from the Windows I/O layer (I suppose?).
|
||||
|
||||
# If test_properlyCloseFiles waited for the SSL handshake to complete
|
||||
# and performed an orderly shutdown, then this would probably be a
|
||||
# little less weird: writing to a shutdown SSL connection has a more
|
||||
# well-defined failure mode (or at least it should).
|
||||
|
||||
# So figure out if twisted.protocols.tls is in use. If it can be
|
||||
# imported, it should be.
|
||||
try:
|
||||
import twisted.protocols.tls
|
||||
except ImportError:
|
||||
# It isn't available, so we expect WSAENOTSOCK if we're on Windows.
|
||||
if platform.getType() == 'win32':
|
||||
return errno.WSAENOTSOCK
|
||||
|
||||
# Otherwise, we expect an error about how we tried to write to a
|
||||
# shutdown connection. This is terribly implementation-specific.
|
||||
return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
|
||||
|
||||
|
||||
|
||||
class TLSTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for startTLS support.
|
||||
|
||||
@ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
|
||||
@type fillBuffer: C{bool}
|
||||
"""
|
||||
fillBuffer = False
|
||||
|
||||
clientProto = None
|
||||
serverProto = None
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
if self.clientProto.transport is not None:
|
||||
self.clientProto.transport.loseConnection()
|
||||
if self.serverProto.transport is not None:
|
||||
self.serverProto.transport.loseConnection()
|
||||
|
||||
|
||||
def _runTest(self, clientProto, serverProto, clientIsServer=False):
|
||||
"""
|
||||
Helper method to run TLS tests.
|
||||
|
||||
@param clientProto: protocol instance attached to the client
|
||||
connection.
|
||||
@param serverProto: protocol instance attached to the server
|
||||
connection.
|
||||
@param clientIsServer: flag indicated if client should initiate
|
||||
startTLS instead of server.
|
||||
|
||||
@return: a L{defer.Deferred} that will fire when both connections are
|
||||
lost.
|
||||
"""
|
||||
self.clientProto = clientProto
|
||||
cf = self.clientFactory = protocol.ClientFactory()
|
||||
cf.protocol = lambda: clientProto
|
||||
if clientIsServer:
|
||||
cf.server = False
|
||||
else:
|
||||
cf.client = True
|
||||
|
||||
self.serverProto = serverProto
|
||||
sf = self.serverFactory = protocol.ServerFactory()
|
||||
sf.protocol = lambda: serverProto
|
||||
if clientIsServer:
|
||||
sf.client = False
|
||||
else:
|
||||
sf.server = True
|
||||
|
||||
port = reactor.listenTCP(0, sf, interface="127.0.0.1")
|
||||
self.addCleanup(port.stopListening)
|
||||
|
||||
reactor.connectTCP('127.0.0.1', port.getHost().port, cf)
|
||||
|
||||
return defer.gatherResults([clientProto.deferred, serverProto.deferred])
|
||||
|
||||
|
||||
def test_TLS(self):
|
||||
"""
|
||||
Test for server and client startTLS: client should received data both
|
||||
before and after the startTLS.
|
||||
"""
|
||||
def check(ignore):
|
||||
self.assertEqual(
|
||||
self.serverFactory.lines,
|
||||
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
|
||||
)
|
||||
d = self._runTest(UnintelligentProtocol(),
|
||||
LineCollector(True, self.fillBuffer))
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
def test_unTLS(self):
|
||||
"""
|
||||
Test for server startTLS not followed by a startTLS in client: the data
|
||||
received after server startTLS should be received as raw.
|
||||
"""
|
||||
def check(ignored):
|
||||
self.assertEqual(
|
||||
self.serverFactory.lines,
|
||||
UnintelligentProtocol.pretext
|
||||
)
|
||||
self.failUnless(self.serverFactory.rawdata,
|
||||
"No encrypted bytes received")
|
||||
d = self._runTest(UnintelligentProtocol(),
|
||||
LineCollector(False, self.fillBuffer))
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
def test_backwardsTLS(self):
|
||||
"""
|
||||
Test startTLS first initiated by client.
|
||||
"""
|
||||
def check(ignored):
|
||||
self.assertEqual(
|
||||
self.clientFactory.lines,
|
||||
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
|
||||
)
|
||||
d = self._runTest(LineCollector(True, self.fillBuffer),
|
||||
UnintelligentProtocol(), True)
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
|
||||
class SpammyTLSTestCase(TLSTestCase):
|
||||
"""
|
||||
Test TLS features with bytes sitting in the out buffer.
|
||||
"""
|
||||
fillBuffer = True
|
||||
|
||||
|
||||
|
||||
class BufferingTestCase(unittest.TestCase):
|
||||
serverProto = None
|
||||
clientProto = None
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
if self.serverProto.transport is not None:
|
||||
self.serverProto.transport.loseConnection()
|
||||
if self.clientProto.transport is not None:
|
||||
self.clientProto.transport.loseConnection()
|
||||
|
||||
|
||||
def test_openSSLBuffering(self):
|
||||
serverProto = self.serverProto = SingleLineServerProtocol()
|
||||
clientProto = self.clientProto = RecordingClientProtocol()
|
||||
|
||||
server = protocol.ServerFactory()
|
||||
client = self.client = protocol.ClientFactory()
|
||||
|
||||
server.protocol = lambda: serverProto
|
||||
client.protocol = lambda: clientProto
|
||||
|
||||
sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
|
||||
cCTX = ssl.ClientContextFactory()
|
||||
|
||||
port = reactor.listenSSL(0, server, sCTX, interface='127.0.0.1')
|
||||
self.addCleanup(port.stopListening)
|
||||
|
||||
reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX)
|
||||
|
||||
return clientProto.deferred.addCallback(
|
||||
self.assertEqual, b"+OK <some crap>\r\n")
|
||||
|
||||
|
||||
|
||||
class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin):
|
||||
"""
|
||||
SSL connection closing tests.
|
||||
"""
|
||||
|
||||
def testImmediateDisconnect(self):
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {},
|
||||
(org, org + ", server"), {})
|
||||
|
||||
# Set up a server, connect to it with a client, which should work since our verifiers
|
||||
# allow anything, then disconnect.
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = protocol.Protocol
|
||||
self.serverPort = serverPort = reactor.listenSSL(0,
|
||||
serverProtocolFactory, self.serverCtxFactory)
|
||||
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
|
||||
clientProtocolFactory.connectionDisconnected = defer.Deferred()
|
||||
clientConnector = reactor.connectSSL('127.0.0.1',
|
||||
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
|
||||
|
||||
return clientProtocolFactory.connectionDisconnected.addCallback(
|
||||
lambda ignoredResult: self.serverPort.stopListening())
|
||||
|
||||
|
||||
def test_bothSidesLoseConnection(self):
|
||||
"""
|
||||
Both sides of SSL connection close connection; the connections should
|
||||
close cleanly, and only after the underlying TCP connection has
|
||||
disconnected.
|
||||
"""
|
||||
class CloseAfterHandshake(protocol.Protocol):
|
||||
gotData = False
|
||||
|
||||
def __init__(self):
|
||||
self.done = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(b"a")
|
||||
|
||||
def dataReceived(self, data):
|
||||
# If we got data, handshake is over:
|
||||
self.gotData = True
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if not self.gotData:
|
||||
reason = RuntimeError("We never received the data!")
|
||||
self.done.errback(reason)
|
||||
del self.done
|
||||
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {},
|
||||
(org, org + ", server"), {})
|
||||
|
||||
serverProtocol = CloseAfterHandshake()
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = lambda: serverProtocol
|
||||
serverPort = reactor.listenSSL(0,
|
||||
serverProtocolFactory, self.serverCtxFactory)
|
||||
self.addCleanup(serverPort.stopListening)
|
||||
|
||||
clientProtocol = CloseAfterHandshake()
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = lambda: clientProtocol
|
||||
clientConnector = reactor.connectSSL('127.0.0.1',
|
||||
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
|
||||
|
||||
def checkResult(failure):
|
||||
failure.trap(ConnectionDone)
|
||||
return defer.gatherResults(
|
||||
[clientProtocol.done.addErrback(checkResult),
|
||||
serverProtocol.done.addErrback(checkResult)])
|
||||
|
||||
if newTLS is None:
|
||||
test_bothSidesLoseConnection.skip = "Old SSL code doesn't always close cleanly."
|
||||
|
||||
|
||||
def testFailedVerify(self):
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {},
|
||||
(org, org + ", server"), {})
|
||||
|
||||
def verify(*a):
|
||||
return False
|
||||
self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
|
||||
|
||||
serverConnLost = defer.Deferred()
|
||||
serverProtocol = protocol.Protocol()
|
||||
serverProtocol.connectionLost = serverConnLost.callback
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = lambda: serverProtocol
|
||||
self.serverPort = serverPort = reactor.listenSSL(0,
|
||||
serverProtocolFactory, self.serverCtxFactory)
|
||||
|
||||
clientConnLost = defer.Deferred()
|
||||
clientProtocol = protocol.Protocol()
|
||||
clientProtocol.connectionLost = clientConnLost.callback
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = lambda: clientProtocol
|
||||
clientConnector = reactor.connectSSL('127.0.0.1',
|
||||
serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
|
||||
|
||||
dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
|
||||
return dl.addCallback(self._cbLostConns)
|
||||
|
||||
|
||||
def _cbLostConns(self, results):
|
||||
(sSuccess, sResult), (cSuccess, cResult) = results
|
||||
|
||||
self.failIf(sSuccess)
|
||||
self.failIf(cSuccess)
|
||||
|
||||
acceptableErrors = [SSL.Error]
|
||||
|
||||
# Rather than getting a verification failure on Windows, we are getting
|
||||
# a connection failure. Without something like sslverify proxying
|
||||
# in-between we can't fix up the platform's errors, so let's just
|
||||
# specifically say it is only OK in this one case to keep the tests
|
||||
# passing. Normally we'd like to be as strict as possible here, so
|
||||
# we're not going to allow this to report errors incorrectly on any
|
||||
# other platforms.
|
||||
|
||||
if platform.isWindows():
|
||||
from twisted.internet.error import ConnectionLost
|
||||
acceptableErrors.append(ConnectionLost)
|
||||
|
||||
sResult.trap(*acceptableErrors)
|
||||
cResult.trap(*acceptableErrors)
|
||||
|
||||
return self.serverPort.stopListening()
|
||||
|
||||
|
||||
|
||||
class FakeContext:
|
||||
"""
|
||||
L{OpenSSL.SSL.Context} double which can more easily be inspected.
|
||||
"""
|
||||
def __init__(self, method):
|
||||
self._method = method
|
||||
self._options = 0
|
||||
|
||||
|
||||
def set_options(self, options):
|
||||
self._options |= options
|
||||
|
||||
|
||||
def use_certificate_file(self, fileName):
|
||||
pass
|
||||
|
||||
|
||||
def use_privatekey_file(self, fileName):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{ssl.DefaultOpenSSLContextFactory}.
|
||||
"""
|
||||
def setUp(self):
|
||||
# pyOpenSSL Context objects aren't introspectable enough. Pass in
|
||||
# an alternate context factory so we can inspect what is done to it.
|
||||
self.contextFactory = ssl.DefaultOpenSSLContextFactory(
|
||||
certPath, certPath, _contextFactory=FakeContext)
|
||||
self.context = self.contextFactory.getContext()
|
||||
|
||||
|
||||
def test_method(self):
|
||||
"""
|
||||
L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
|
||||
which can use SSLv3 or TLSv1 but not SSLv2.
|
||||
"""
|
||||
# SSLv23_METHOD allows SSLv2, SSLv3, or TLSv1
|
||||
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
|
||||
|
||||
# And OP_NO_SSLv2 disables the SSLv2 support.
|
||||
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
|
||||
|
||||
# Make sure SSLv3 and TLSv1 aren't disabled though.
|
||||
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
|
||||
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
|
||||
|
||||
|
||||
def test_missingCertificateFile(self):
|
||||
"""
|
||||
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
|
||||
filename which does not identify an existing file results in the
|
||||
initializer raising L{OpenSSL.SSL.Error}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
SSL.Error,
|
||||
ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp())
|
||||
|
||||
|
||||
def test_missingPrivateKeyFile(self):
|
||||
"""
|
||||
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
|
||||
filename which does not identify an existing file results in the
|
||||
initializer raising L{OpenSSL.SSL.Error}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
SSL.Error,
|
||||
ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath)
|
||||
|
||||
|
||||
|
||||
class ClientContextFactoryTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{ssl.ClientContextFactory}.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.contextFactory = ssl.ClientContextFactory()
|
||||
self.contextFactory._contextFactory = FakeContext
|
||||
self.context = self.contextFactory.getContext()
|
||||
|
||||
|
||||
def test_method(self):
|
||||
"""
|
||||
L{ssl.ClientContextFactory.getContext} returns a context which can use
|
||||
SSLv3 or TLSv1 but not SSLv2.
|
||||
"""
|
||||
self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
|
||||
self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
|
||||
self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
|
||||
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
|
||||
|
||||
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase,
|
||||
BufferingTestCase, ConnectionLostTestCase,
|
||||
DefaultOpenSSLContextFactoryTests,
|
||||
ClientContextFactoryTests]:
|
||||
tCase.skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.stateful
|
||||
"""
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.protocols.test import test_basic
|
||||
from twisted.protocols.stateful import StatefulProtocol
|
||||
|
||||
from struct import pack, unpack, calcsize
|
||||
|
||||
|
||||
class MyInt32StringReceiver(StatefulProtocol):
|
||||
"""
|
||||
A stateful Int32StringReceiver.
|
||||
"""
|
||||
MAX_LENGTH = 99999
|
||||
structFormat = "!I"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
def getInitialState(self):
|
||||
return self._getHeader, 4
|
||||
|
||||
def lengthLimitExceeded(self, length):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def _getHeader(self, msg):
|
||||
length, = unpack("!i", msg)
|
||||
if length > self.MAX_LENGTH:
|
||||
self.lengthLimitExceeded(length)
|
||||
return
|
||||
return self._getString, length
|
||||
|
||||
def _getString(self, msg):
|
||||
self.stringReceived(msg)
|
||||
return self._getHeader, 4
|
||||
|
||||
def stringReceived(self, msg):
|
||||
"""
|
||||
Override this.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sendString(self, data):
|
||||
"""
|
||||
Send an int32-prefixed string to the other end of the connection.
|
||||
"""
|
||||
self.transport.write(pack(self.structFormat, len(data)) + data)
|
||||
|
||||
|
||||
class TestInt32(MyInt32StringReceiver):
|
||||
def connectionMade(self):
|
||||
self.received = []
|
||||
|
||||
def stringReceived(self, s):
|
||||
self.received.append(s)
|
||||
|
||||
MAX_LENGTH = 50
|
||||
closed = 0
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.closed = 1
|
||||
|
||||
|
||||
class Int32TestCase(TestCase, test_basic.IntNTestCaseMixin):
|
||||
protocol = TestInt32
|
||||
strings = ["a", "b" * 16]
|
||||
illegalStrings = ["\x10\x00\x00\x00aaaaaa"]
|
||||
partialStrings = ["\x00\x00\x00", "hello there", ""]
|
||||
|
||||
def test_bigReceive(self):
|
||||
r = self.getProtocol()
|
||||
big = ""
|
||||
for s in self.strings * 4:
|
||||
big += pack("!i", len(s)) + s
|
||||
r.dataReceived(big)
|
||||
self.assertEqual(r.received, self.strings * 4)
|
||||
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.stdio}.
|
||||
"""
|
||||
|
||||
import os, sys, itertools
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import filepath, log
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.internet import error, defer, protocol, stdio, reactor
|
||||
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
|
||||
|
||||
|
||||
# A short string which is intended to appear here and nowhere else,
|
||||
# particularly not in any random garbage output CPython unavoidable
|
||||
# generates (such as in warning text and so forth). This is searched
|
||||
# for in the output from stdio_test_lastwrite.py and if it is found at
|
||||
# the end, the functionality works.
|
||||
UNIQUE_LAST_WRITE_STRING = 'xyz123abc Twisted is great!'
|
||||
|
||||
skipWindowsNopywin32 = None
|
||||
if platform.isWindows():
|
||||
try:
|
||||
import win32process
|
||||
except ImportError:
|
||||
skipWindowsNopywin32 = ("On windows, spawnProcess is not available "
|
||||
"in the absence of win32process.")
|
||||
|
||||
|
||||
class StandardIOTestProcessProtocol(protocol.ProcessProtocol):
|
||||
"""
|
||||
Test helper for collecting output from a child process and notifying
|
||||
something when it exits.
|
||||
|
||||
@ivar onConnection: A L{defer.Deferred} which will be called back with
|
||||
C{None} when the connection to the child process is established.
|
||||
|
||||
@ivar onCompletion: A L{defer.Deferred} which will be errbacked with the
|
||||
failure associated with the child process exiting when it exits.
|
||||
|
||||
@ivar onDataReceived: A L{defer.Deferred} which will be called back with
|
||||
this instance whenever C{childDataReceived} is called, or C{None} to
|
||||
suppress these callbacks.
|
||||
|
||||
@ivar data: A C{dict} mapping file descriptors to strings containing all
|
||||
bytes received from the child process on each file descriptor.
|
||||
"""
|
||||
onDataReceived = None
|
||||
|
||||
def __init__(self):
|
||||
self.onConnection = defer.Deferred()
|
||||
self.onCompletion = defer.Deferred()
|
||||
self.data = {}
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.onConnection.callback(None)
|
||||
|
||||
|
||||
def childDataReceived(self, name, bytes):
|
||||
"""
|
||||
Record all bytes received from the child process in the C{data}
|
||||
dictionary. Fire C{onDataReceived} if it is not C{None}.
|
||||
"""
|
||||
self.data[name] = self.data.get(name, '') + bytes
|
||||
if self.onDataReceived is not None:
|
||||
d, self.onDataReceived = self.onDataReceived, None
|
||||
d.callback(self)
|
||||
|
||||
|
||||
def processEnded(self, reason):
|
||||
self.onCompletion.callback(reason)
|
||||
|
||||
|
||||
|
||||
class StandardInputOutputTestCase(unittest.TestCase):
|
||||
|
||||
skip = skipWindowsNopywin32
|
||||
|
||||
def _spawnProcess(self, proto, sibling, *args, **kw):
|
||||
"""
|
||||
Launch a child Python process and communicate with it using the
|
||||
given ProcessProtocol.
|
||||
|
||||
@param proto: A L{ProcessProtocol} instance which will be connected
|
||||
to the child process.
|
||||
|
||||
@param sibling: The basename of a file containing the Python program
|
||||
to run in the child process.
|
||||
|
||||
@param *args: strings which will be passed to the child process on
|
||||
the command line as C{argv[2:]}.
|
||||
|
||||
@param **kw: additional arguments to pass to L{reactor.spawnProcess}.
|
||||
|
||||
@return: The L{IProcessTransport} provider for the spawned process.
|
||||
"""
|
||||
import twisted
|
||||
subenv = dict(os.environ)
|
||||
subenv['PYTHONPATH'] = os.pathsep.join(
|
||||
[os.path.abspath(
|
||||
os.path.dirname(os.path.dirname(twisted.__file__))),
|
||||
subenv.get('PYTHONPATH', '')
|
||||
])
|
||||
args = [sys.executable,
|
||||
filepath.FilePath(__file__).sibling(sibling).path,
|
||||
reactor.__class__.__module__] + list(args)
|
||||
return reactor.spawnProcess(
|
||||
proto,
|
||||
sys.executable,
|
||||
args,
|
||||
env=subenv,
|
||||
**kw)
|
||||
|
||||
|
||||
def _requireFailure(self, d, callback):
|
||||
def cb(result):
|
||||
self.fail("Process terminated with non-Failure: %r" % (result,))
|
||||
def eb(err):
|
||||
return callback(err)
|
||||
return d.addCallbacks(cb, eb)
|
||||
|
||||
|
||||
def test_loseConnection(self):
|
||||
"""
|
||||
Verify that a protocol connected to L{StandardIO} can disconnect
|
||||
itself using C{transport.loseConnection}.
|
||||
"""
|
||||
errorLogFile = self.mktemp()
|
||||
log.msg("Child process logging to " + errorLogFile)
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
self._spawnProcess(p, 'stdio_test_loseconn.py', errorLogFile)
|
||||
|
||||
def processEnded(reason):
|
||||
# Copy the child's log to ours so it's more visible.
|
||||
for line in file(errorLogFile):
|
||||
log.msg("Child logged: " + line.rstrip())
|
||||
|
||||
self.failIfIn(1, p.data)
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def test_readConnectionLost(self):
|
||||
"""
|
||||
When stdin is closed and the protocol connected to it implements
|
||||
L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method
|
||||
is called.
|
||||
"""
|
||||
errorLogFile = self.mktemp()
|
||||
log.msg("Child process logging to " + errorLogFile)
|
||||
p = StandardIOTestProcessProtocol()
|
||||
p.onDataReceived = defer.Deferred()
|
||||
|
||||
def cbBytes(ignored):
|
||||
d = p.onCompletion
|
||||
p.transport.closeStdin()
|
||||
return d
|
||||
p.onDataReceived.addCallback(cbBytes)
|
||||
|
||||
def processEnded(reason):
|
||||
reason.trap(error.ProcessDone)
|
||||
d = self._requireFailure(p.onDataReceived, processEnded)
|
||||
|
||||
self._spawnProcess(
|
||||
p, 'stdio_test_halfclose.py', errorLogFile)
|
||||
return d
|
||||
|
||||
|
||||
def test_lastWriteReceived(self):
|
||||
"""
|
||||
Verify that a write made directly to stdout using L{os.write}
|
||||
after StandardIO has finished is reliably received by the
|
||||
process reading that stdout.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
|
||||
# Note: the OS X bug which prompted the addition of this test
|
||||
# is an apparent race condition involving non-blocking PTYs.
|
||||
# Delaying the parent process significantly increases the
|
||||
# likelihood of the race going the wrong way. If you need to
|
||||
# fiddle with this code at all, uncommenting the next line
|
||||
# will likely make your life much easier. It is commented out
|
||||
# because it makes the test quite slow.
|
||||
|
||||
# p.onConnection.addCallback(lambda ign: __import__('time').sleep(5))
|
||||
|
||||
try:
|
||||
self._spawnProcess(
|
||||
p, 'stdio_test_lastwrite.py', UNIQUE_LAST_WRITE_STRING,
|
||||
usePTY=True)
|
||||
except ValueError, e:
|
||||
# Some platforms don't work with usePTY=True
|
||||
raise unittest.SkipTest(str(e))
|
||||
|
||||
def processEnded(reason):
|
||||
"""
|
||||
Asserts that the parent received the bytes written by the child
|
||||
immediately after the child starts.
|
||||
"""
|
||||
self.assertTrue(
|
||||
p.data[1].endswith(UNIQUE_LAST_WRITE_STRING),
|
||||
"Received %r from child, did not find expected bytes." % (
|
||||
p.data,))
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(p.onCompletion, processEnded)
|
||||
|
||||
|
||||
def test_hostAndPeer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
has C{getHost} and C{getPeer} methods.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
self._spawnProcess(p, 'stdio_test_hostpeer.py')
|
||||
|
||||
def processEnded(reason):
|
||||
host, peer = p.data[1].splitlines()
|
||||
self.failUnless(host)
|
||||
self.failUnless(peer)
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def test_write(self):
|
||||
"""
|
||||
Verify that the C{write} method of the transport of a protocol
|
||||
connected to L{StandardIO} sends bytes to standard out.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
self._spawnProcess(p, 'stdio_test_write.py')
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], 'ok!')
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def test_writeSequence(self):
|
||||
"""
|
||||
Verify that the C{writeSequence} method of the transport of a
|
||||
protocol connected to L{StandardIO} sends bytes to standard out.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
self._spawnProcess(p, 'stdio_test_writeseq.py')
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], 'ok!')
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def _junkPath(self):
|
||||
junkPath = self.mktemp()
|
||||
junkFile = file(junkPath, 'w')
|
||||
for i in xrange(1024):
|
||||
junkFile.write(str(i) + '\n')
|
||||
junkFile.close()
|
||||
return junkPath
|
||||
|
||||
|
||||
def test_producer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
is a working L{IProducer} provider.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
written = []
|
||||
toWrite = range(100)
|
||||
|
||||
def connectionMade(ign):
|
||||
if toWrite:
|
||||
written.append(str(toWrite.pop()) + "\n")
|
||||
proc.write(written[-1])
|
||||
reactor.callLater(0.01, connectionMade, None)
|
||||
|
||||
proc = self._spawnProcess(p, 'stdio_test_producer.py')
|
||||
|
||||
p.onConnection.addCallback(connectionMade)
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], ''.join(written))
|
||||
self.failIf(toWrite, "Connection lost with %d writes left to go." % (len(toWrite),))
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def test_consumer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
is a working L{IConsumer} provider.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
junkPath = self._junkPath()
|
||||
|
||||
self._spawnProcess(p, 'stdio_test_consumer.py', junkPath)
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], file(junkPath).read())
|
||||
reason.trap(error.ProcessDone)
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
|
||||
def test_normalFileStandardOut(self):
|
||||
"""
|
||||
If L{StandardIO} is created with a file descriptor which refers to a
|
||||
normal file (ie, a file from the filesystem), L{StandardIO.write}
|
||||
writes bytes to that file. In particular, it does not immediately
|
||||
consider the file closed or call its protocol's C{connectionLost}
|
||||
method.
|
||||
"""
|
||||
onConnLost = defer.Deferred()
|
||||
proto = ConnectionLostNotifyingProtocol(onConnLost)
|
||||
path = filepath.FilePath(self.mktemp())
|
||||
self.normal = normal = path.open('w')
|
||||
self.addCleanup(normal.close)
|
||||
|
||||
kwargs = dict(stdout=normal.fileno())
|
||||
if not platform.isWindows():
|
||||
# Make a fake stdin so that StandardIO doesn't mess with the *real*
|
||||
# stdin.
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
kwargs['stdin'] = r
|
||||
connection = stdio.StandardIO(proto, **kwargs)
|
||||
|
||||
# The reactor needs to spin a bit before it might have incorrectly
|
||||
# decided stdout is closed. Use this counter to keep track of how
|
||||
# much we've let it spin. If it closes before we expected, this
|
||||
# counter will have a value that's too small and we'll know.
|
||||
howMany = 5
|
||||
count = itertools.count()
|
||||
|
||||
def spin():
|
||||
for value in count:
|
||||
if value == howMany:
|
||||
connection.loseConnection()
|
||||
return
|
||||
connection.write(str(value))
|
||||
break
|
||||
reactor.callLater(0, spin)
|
||||
reactor.callLater(0, spin)
|
||||
|
||||
# Once the connection is lost, make sure the counter is at the
|
||||
# appropriate value.
|
||||
def cbLost(reason):
|
||||
self.assertEqual(count.next(), howMany + 1)
|
||||
self.assertEqual(
|
||||
path.getContent(),
|
||||
''.join(map(str, range(howMany))))
|
||||
onConnLost.addCallback(cbLost)
|
||||
return onConnLost
|
||||
|
||||
if platform.isWindows():
|
||||
test_normalFileStandardOut.skip = (
|
||||
"StandardIO does not accept stdout as an argument to Windows. "
|
||||
"Testing redirection to a file is therefore harder.")
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue