# -*- test-case-name: twisted.test.test_loopback -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Testing support for protocols -- loopback between client and server.
"""

from __future__ import division, absolute_import

# system imports
import tempfile

from zope.interface import implementer

# Twisted Imports
from twisted.protocols import policies
from twisted.internet import interfaces, protocol, main, defer
from twisted.internet.task import deferLater
from twisted.python import failure
from twisted.internet.interfaces import IAddress


class _LoopbackQueue(object):
    """
    Trivial wrapper around a list to give it an interface like a queue, with
    the addition of also sending notifications by way of a Deferred whenever
    the list has something added to it.
    """

    _notificationDeferred = None
    disconnect = False

    def __init__(self):
        self._queue = []


    def put(self, v):
        self._queue.append(v)
        if self._notificationDeferred is not None:
            d, self._notificationDeferred = self._notificationDeferred, None
            d.callback(None)


    def __nonzero__(self):
        return bool(self._queue)
    __bool__ = __nonzero__


    def get(self):
        return self._queue.pop(0)
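
# A minimal illustrative sketch of how _LoopbackQueue behaves (the variable
# names here are hypothetical, not part of the module): the queue is truthy
# while it holds items, get() pops in FIFO order, and a Deferred assigned to
# _notificationDeferred fires once on the next put().
#
#     q = _LoopbackQueue()
#     q._notificationDeferred = defer.Deferred()
#     q._notificationDeferred.addCallback(lambda ignored: None)
#     q.put(b"data")   # fires and clears the notification Deferred
#     bool(q)          # True -- one item queued
#     q.get()          # b"data"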



@implementer(IAddress)
class _LoopbackAddress(object):
    pass



@implementer(interfaces.ITransport, interfaces.IConsumer)
class _LoopbackTransport(object):
    disconnecting = False
    producer = None

    # ITransport
    def __init__(self, q):
        self.q = q

    def write(self, data):
        if not isinstance(data, bytes):
            raise TypeError("Can only write bytes to ITransport")
        self.q.put(data)

    def writeSequence(self, iovec):
        self.q.put(b''.join(iovec))

    def loseConnection(self):
        self.q.disconnect = True
        self.q.put(None)

    def getPeer(self):
        return _LoopbackAddress()

    def getHost(self):
        return _LoopbackAddress()

    # IConsumer
    def registerProducer(self, producer, streaming):
        assert self.producer is None
        self.producer = producer
        self.streamingProducer = streaming
        self._pollProducer()

    def unregisterProducer(self):
        assert self.producer is not None
        self.producer = None

    def _pollProducer(self):
        if self.producer is not None and not self.streamingProducer:
            self.producer.resumeProducing()



def identityPumpPolicy(queue, target):
    """
    L{identityPumpPolicy} is a policy which delivers each chunk of data written
    to the given queue as-is to the target.

    This isn't a particularly realistic policy.

    @see: L{loopbackAsync}
    """
    while queue:
        bytes = queue.get()
        if bytes is None:
            break
        target.dataReceived(bytes)



def collapsingPumpPolicy(queue, target):
    """
    L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks
    into a single string and delivers it to the target.

    @see: L{loopbackAsync}
    """
    bytes = []
    while queue:
        chunk = queue.get()
        if chunk is None:
            break
        bytes.append(chunk)
    if bytes:
        target.dataReceived(b''.join(bytes))
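
# A pump policy is any callable taking (queue, target); tests can supply their
# own.  For instance, a hypothetical byte-at-a-time policy (a sketch, not part
# of this module) which re-delivers each chunk one byte at a time might look
# like this:
#
#     def oneByteAtATimePumpPolicy(queue, target):
#         while queue:
#             chunk = queue.get()
#             if chunk is None:
#                 break
#             for i in range(len(chunk)):
#                 target.dataReceived(chunk[i:i + 1])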



def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy):
    """
    Establish a connection between C{server} and C{client} then transfer data
    between them until the connection is closed.  This is often useful for
    testing a protocol.

    @param server: The protocol instance representing the server-side of this
        connection.

    @param client: The protocol instance representing the client-side of this
        connection.

    @param pumpPolicy: When either C{server} or C{client} writes to its
        transport, the string passed in is added to a queue of data for the
        other protocol.  Eventually, C{pumpPolicy} will be called with one such
        queue and the corresponding protocol object.  The pump policy callable
        is responsible for emptying the queue and passing the strings it
        contains to the given protocol's C{dataReceived} method.  The signature
        of C{pumpPolicy} is C{(queue, protocol)}.  C{queue} is an object with a
        C{get} method which will return the next string written to the
        transport, or C{None} if the transport has been disconnected, and which
        evaluates to C{True} if and only if there are more items to be
        retrieved via C{get}.

    @return: A L{Deferred} which fires when the connection has been closed and
        both sides have received notification of this.
    """
    serverToClient = _LoopbackQueue()
    clientToServer = _LoopbackQueue()

    server.makeConnection(_LoopbackTransport(serverToClient))
    client.makeConnection(_LoopbackTransport(clientToServer))

    return _loopbackAsyncBody(
        server, serverToClient, client, clientToServer, pumpPolicy)
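
# A minimal usage sketch (EchoServer and RecordingClient are hypothetical
# protocol instances a test would supply; only loopbackAsync and the pump
# policies above come from this module):
#
#     server = EchoServer()          # some IProtocol implementation
#     client = RecordingClient()     # another IProtocol implementation
#     d = loopbackAsync(server, client, pumpPolicy=collapsingPumpPolicy)
#     # d fires with None once both sides have seen connectionLost.
#     d.addCallback(lambda ignored: client.received)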



def _loopbackAsyncBody(server, serverToClient, client, clientToServer,
                       pumpPolicy):
    """
    Transfer bytes from the output queue of each protocol to the input of the
    other.

    @param server: The protocol instance representing the server-side of this
        connection.

    @param serverToClient: The L{_LoopbackQueue} holding the server's output.

    @param client: The protocol instance representing the client-side of this
        connection.

    @param clientToServer: The L{_LoopbackQueue} holding the client's output.

    @param pumpPolicy: See L{loopbackAsync}.

    @return: A L{Deferred} which fires when the connection has been closed and
        both sides have received notification of this.
    """
    def pump(source, q, target):
        sent = False
        if q:
            pumpPolicy(q, target)
            sent = True
        if sent and not q:
            # A write buffer has now been emptied.  Give any producer on that
            # side an opportunity to produce more data.
            source.transport._pollProducer()

        return sent

    while 1:
        disconnect = clientSent = serverSent = False

        # Deliver the data which has been written.
        serverSent = pump(server, serverToClient, client)
        clientSent = pump(client, clientToServer, server)

        if not clientSent and not serverSent:
            # Neither side wrote any data.  Wait for some new data to be added
            # before trying to do anything further.
            d = defer.Deferred()
            clientToServer._notificationDeferred = d
            serverToClient._notificationDeferred = d
            d.addCallback(
                _loopbackAsyncContinue,
                server, serverToClient, client, clientToServer, pumpPolicy)
            return d
        if serverToClient.disconnect:
            # The server wants to drop the connection.  Flush any remaining
            # data it has.
            disconnect = True
            pump(server, serverToClient, client)
        elif clientToServer.disconnect:
            # The client wants to drop the connection.  Flush any remaining
            # data it has.
            disconnect = True
            pump(client, clientToServer, server)
        if disconnect:
            # Someone wanted to disconnect, so okay, the connection is gone.
            server.connectionLost(failure.Failure(main.CONNECTION_DONE))
            client.connectionLost(failure.Failure(main.CONNECTION_DONE))
            return defer.succeed(None)



def _loopbackAsyncContinue(ignored, server, serverToClient, client,
                           clientToServer, pumpPolicy):
    # Clear the Deferred from each message queue, since it has already fired
    # and cannot be used again.
    clientToServer._notificationDeferred = None
    serverToClient._notificationDeferred = None

    # Schedule some more byte-pushing to happen.  This isn't done
    # synchronously because no actual transport can re-enter dataReceived as
    # a result of calling write, and doing this synchronously could result
    # in that.
    from twisted.internet import reactor
    return deferLater(
        reactor, 0,
        _loopbackAsyncBody,
        server, serverToClient, client, clientToServer, pumpPolicy)



@implementer(interfaces.ITransport, interfaces.IConsumer)
class LoopbackRelay:
    """
    A transport which buffers data written to it and delivers it to a target
    protocol only when L{clearBuffer} is called.
    """
    buffer = ''
    shouldLose = 0
    disconnecting = 0
    producer = None

    def __init__(self, target, logFile=None):
        self.target = target
        self.logFile = logFile

    def write(self, data):
        self.buffer = self.buffer + data
        if self.logFile:
            self.logFile.write("loopback writing %s\n" % repr(data))

    def writeSequence(self, iovec):
        self.write("".join(iovec))

    def clearBuffer(self):
        if self.shouldLose == -1:
            return

        if self.producer:
            self.producer.resumeProducing()
        if self.buffer:
            if self.logFile:
                self.logFile.write("loopback receiving %s\n" % repr(self.buffer))
            buffer = self.buffer
            self.buffer = ''
            self.target.dataReceived(buffer)
        if self.shouldLose == 1:
            self.shouldLose = -1
            self.target.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def loseConnection(self):
        if self.shouldLose != -1:
            self.shouldLose = 1

    def getHost(self):
        return 'loopback'

    def getPeer(self):
        return 'loopback'

    def registerProducer(self, producer, streaming):
        self.producer = producer

    def unregisterProducer(self):
        self.producer = None

    def logPrefix(self):
        return 'Loopback(%r)' % (self.target.__class__.__name__,)
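
# A minimal usage sketch for LoopbackRelay (serverProto and clientProto are
# hypothetical protocol instances supplied by a test; each relay serves as the
# transport for one side and delivers to the other side's protocol):
#
#     serverTransport = LoopbackRelay(clientProto)  # server writes reach the client
#     clientTransport = LoopbackRelay(serverProto)  # client writes reach the server
#     serverProto.makeConnection(serverTransport)
#     clientProto.makeConnection(clientTransport)
#     ...                               # let the protocols write some data
#     clientTransport.clearBuffer()     # deliver the client's buffered bytes
#     serverTransport.clearBuffer()     # deliver the server's buffered bytes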



class LoopbackClientFactory(protocol.ClientFactory):
    """
    A client factory which hands out an already-constructed protocol instance
    and fires a Deferred when its connection is lost.
    """

    def __init__(self, protocol):
        self.disconnected = 0
        self.deferred = defer.Deferred()
        self.protocol = protocol

    def buildProtocol(self, addr):
        return self.protocol

    def clientConnectionLost(self, connector, reason):
        self.disconnected = 1
        self.deferred.callback(None)


class _FireOnClose(policies.ProtocolWrapper):
    """
    A protocol wrapper which fires a Deferred when the wrapped connection is
    lost.
    """
    def __init__(self, protocol, factory):
        policies.ProtocolWrapper.__init__(self, protocol, factory)
        self.deferred = defer.Deferred()

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        self.deferred.callback(None)


def loopbackTCP(server, client, port=0, noisy=True):
    """Run session between server and client protocol instances over TCP."""
    from twisted.internet import reactor
    f = policies.WrappingFactory(protocol.Factory())
    serverWrapper = _FireOnClose(f, server)
    f.noisy = noisy
    f.buildProtocol = lambda addr: serverWrapper
    serverPort = reactor.listenTCP(port, f, interface='127.0.0.1')
    clientF = LoopbackClientFactory(client)
    clientF.noisy = noisy
    reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF)
    d = clientF.deferred
    d.addCallback(lambda x: serverWrapper.deferred)
    d.addCallback(lambda x: serverPort.stopListening())
    return d
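
# A minimal usage sketch inside a trial test method (EchoServer and
# RecordingClient are hypothetical protocol instances a test would supply):
#
#     def test_echoOverTCP(self):
#         client = RecordingClient()
#         d = loopbackTCP(EchoServer(), client)
#         d.addCallback(lambda ignored: self.assertTrue(client.received))
#         return d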


def loopbackUNIX(server, client, noisy=True):
    """Run session between server and client protocol instances over UNIX socket."""
    path = tempfile.mktemp()
    from twisted.internet import reactor
    f = policies.WrappingFactory(protocol.Factory())
    serverWrapper = _FireOnClose(f, server)
    f.noisy = noisy
    f.buildProtocol = lambda addr: serverWrapper
    serverPort = reactor.listenUNIX(path, f)
    clientF = LoopbackClientFactory(client)
    clientF.noisy = noisy
    reactor.connectUNIX(path, clientF)
    d = clientF.deferred
    d.addCallback(lambda x: serverWrapper.deferred)
    d.addCallback(lambda x: serverPort.stopListening())
    return d