# rpc.py
# Copyright (C) 2003, 2004 by David Handy
"""
Remote Procedure Call

This module enables a program to call methods on an object in another
program.

Example:

A simple "client" and "server". The Client and Server classes are
instantiated when the connection is complete. The client blocks until it
connects; for asynchronous connection create a Connector and call run()
in a background thread.

In this example the client and server run on different threads, but they
could just as well run on different computers.
    
>>> # from cpif.rpc import Listener, connect
>>> import threading
>>> class Client:
...     def __init__(self, connection):
...         self.connection = connection
...     def callMe(self):
...         print 'server called client'
>>> class Server:
...     def __init__(self, connection):
...         self.connection = connection
...     def add(self, a, b):
...         return a + b
...     def callMeBack(self):
...         self.connection.callMethod('callMe', (), None)
...     def stop(self):
...         self.connection.node.stop()
...         return 'goodbye'
>>> # start server
>>> listener = Listener(('127.0.0.1', 9098), Server)
>>> threading.Thread(target=listener.run).start()
>>> # start client
>>> connection = connect(('127.0.0.1', 9098), Client)
>>> proxy = connection.getBlockingProxy(1.0)
>>> print proxy.add(3, 4)
7
>>> proxy.callMeBack()
server called client
>>> print proxy.stop()
goodbye
"""

import copy
import cPickle
import cStringIO
import errno
import Queue
import select
import sets
import socket
import struct
import sys
import threading
import time
import traceback

# This module's own module object. LimitedUnpickler.find_global looks up the
# classes it is willing to unpickle as attributes of this object.
module = sys.modules[__name__]


# Public API of this module. The indentation of the exception names mirrors
# the exception class hierarchy. Several entries (Listener, Connector,
# Connection, connect, ...) are presumably defined further down in this file.
__all__ = [
    # Exceptions
    'RpcException',
        'RemoteException',
            'MethodNotFound', 'ExceptionInMethod', 'UnmarshallingException',
        'LocalException',
            'TimeoutException', 'Disconnected',
    # Classes
    'LimitedUnpickler', 'Listener', 'Connector', 'Connection',
    'Proxy', 'BlockingProxy', 'Result', 'InProcNode',
    # Functions
    'connect',
    'getInProcConnection',
    'lookupInetHost',
    'lookupInetHostInBackground',
    ]


# errno values that a non-blocking socket reports when a read (recv/accept)
# or a write (send/connect) merely could not complete yet. Callers treat
# them as "try again later" instead of as failures.
if hasattr(errno, 'WSAEWOULDBLOCK'):
    # Windows: Winsock uses its own set of error codes.
    _expected_socket_read_errors = (
        errno.WSAEWOULDBLOCK,
        errno.WSAEINPROGRESS,
        errno.WSAEINTR,
    )
    _expected_socket_write_errors = (
        errno.WSAEWOULDBLOCK,
        errno.WSAEINPROGRESS,
    )
else:
    # On some systems several of these names are aliases for the same
    # number (e.g. EAGAIN and EWOULDBLOCK), so duplicates are dropped.
    _expected_socket_read_errors = tuple(dict.fromkeys([
        errno.EWOULDBLOCK,
        errno.EAGAIN,
        errno.EINTR,
    ]).keys())
    _expected_socket_write_errors = tuple(dict.fromkeys([
        errno.EAGAIN,
        errno.EWOULDBLOCK,
        errno.ENOBUFS,
        errno.EINTR,
        errno.EINPROGRESS,
    ]).keys())


class RpcException(Exception):
    """Base class of all rpc exceptions.

    Subclasses are split into RemoteException (raised by the remote side of
    a call) and LocalException (raised by the local side).
    """

class RemoteException(RpcException):
    """Exception raised by the remote side of a Remote Procedure Call.

    RemoteException and its subclasses found in this module are the only
    classes LimitedUnpickler will reconstruct from a received message.
    """

class MethodNotFound(RemoteException):
    """An attempt was made to call a method that doesn't exist on the remote
    object.

    args holds (method_name, obj_repr). Only the repr() of the searched
    object is stored, never the object itself.
    """
    def __init__(self, method_name=None, obj=None):
        """
        method_name: String containing the name of the method
        obj: Object that was searched for the method

        The parameters default to None only so that __reduce__ can rebuild
        the exception during unpickling; normal callers pass both.
        """
        try:
            obj_repr = repr(obj)
        except Exception:
            # A broken __repr__ must not prevent reporting the error.
            obj_repr = '<error-in-object-repr>'
        RemoteException.__init__(self, method_name, obj_repr)

    def __reduce__(self):
        # On Python 2.5+ an exception pickles by default as cls(*self.args),
        # which would feed the stored repr string back through __init__ and
        # repr() it a second time. Rebuild with no arguments and restore
        # args (plus any instance attributes) as pickled state instead. The
        # only global referenced is this class, so LimitedUnpickler accepts
        # it. (Classic-class exceptions on older Pythons ignore __reduce__.)
        state = dict(getattr(self, '__dict__', None) or {})
        state['args'] = self.args
        return (self.__class__, (), state)

    def __str__(self):
        return "MethodNotFound: '%s' on %s" % tuple(self.args[:2])

class ExceptionInMethod(RemoteException):
    """An exception occurred when attempting to execute the method on the
    remote object.

    args holds (method_name, obj_repr, traceback_text).
    """
    def __init__(self, method_name=None, obj=None, exc_info=None):
        """
        method_name: String containing the name of the method
        obj: Object implementing the method
        exc_info: Exception info returned by sys.exc_info(). If None,
                  sys.exc_info() will be called.

        Stores a string containing up to 200 stack trace entries from the
        exception traceback.

        method_name and obj default to None only so that __reduce__ can
        rebuild the exception during unpickling; normal callers pass them.
        """
        try:
            obj_repr = repr(obj)
        except Exception:
            # A broken __repr__ must not prevent reporting the error.
            obj_repr = '<error-in-object-repr>'
        if exc_info:
            etype, value, tb = exc_info
        else:
            etype, value, tb = sys.exc_info()
        sb = cStringIO.StringIO()
        traceback.print_exception(etype, value, tb, limit=200, file=sb)
        RemoteException.__init__(self, method_name, obj_repr, sb.getvalue())

    def __reduce__(self):
        # On Python 2.5+ an exception pickles by default as cls(*self.args).
        # Here that would pass the formatted traceback string in as exc_info,
        # and unpacking it into (etype, value, tb) raises ValueError, so the
        # exception could not be unpickled. Rebuild with no arguments and
        # restore args (plus any instance attributes) as pickled state
        # instead. The only global referenced is this class, so
        # LimitedUnpickler accepts it.
        state = dict(getattr(self, '__dict__', None) or {})
        state['args'] = self.args
        return (self.__class__, (), state)

    def __str__(self):
        return "ExceptionInMethod: '%s' on %s:\n%s" % tuple(self.args[:3])

class UnmarshallingException(RemoteException):
    """An error occurred when trying to unmarshall (read into an object) a
    message received via an rpc Connection.

    args holds a single string: the formatted traceback of the failure.
    """
    def __init__(self, exc_info=None):
        """
        exc_info: Exception info returned by sys.exc_info(). If None,
                  sys.exc_info() will be called.

        Stores a string containing up to 200 stack trace entries from the
        exception traceback.
        """
        if exc_info:
            etype, value, tb = exc_info
        else:
            etype, value, tb = sys.exc_info()
        sb = cStringIO.StringIO()
        traceback.print_exception(etype, value, tb, limit=200, file=sb)
        RemoteException.__init__(self, sb.getvalue())

    def __reduce__(self):
        # On Python 2.5+ an exception pickles by default as cls(*self.args).
        # Here that would pass the traceback string in as exc_info, and
        # unpacking it into (etype, value, tb) raises ValueError. Rebuild
        # with no arguments and restore args (plus any instance attributes)
        # as pickled state instead. The only global referenced is this
        # class, so LimitedUnpickler accepts it.
        state = dict(getattr(self, '__dict__', None) or {})
        state['args'] = self.args
        return (self.__class__, (), state)

    def __str__(self):
        return "UnmarshallingException: %s" % str(self.args[0])

class LocalException(RpcException):
    """An exception raised by the local side of a remote procedure call.

    Also used for local misuse errors, e.g. calling _SocketMultiplexer.poll()
    from two threads at once. LimitedUnpickler refuses to unpickle it and its
    subclasses, so they never arrive from the remote side.
    """

class TimeoutException(LocalException):
    """Timed out waiting for the remote connection.

    A LocalException: raised on this side only, never received from the
    remote node.
    """

class Disconnected(LocalException):
    """Disconnected from the remote rpc node.

    A LocalException: raised on this side only, never received from the
    remote node.
    """


class LimitedUnpickler:
    """
    Class for unmarshalling messages. Uses cPickle.Unpickler, but for
    increased security it is limited in what it can load. The only class
    instances it is allowed to unmarshall are instances of RemoteException
    or any of its derived classes defined in this module. It can also
    unmarshall the standard built-in Python types: string, int, float, list,
    tuple, dict, bool, None, etc.

    You could override find_global to allow it to unmarshall instances of
    other classes in other modules. See the pickle module documentation for
    more details.
    """

    def __init__(self, fileobj):
        """fileobj: file-like object the pickled message is read from."""
        self._u = cPickle.Unpickler(fileobj)
        # cPickle calls this hook for every global (class) reference in the
        # stream; that is where the restriction is enforced.
        self._u.find_global = self.find_global

    def find_global(self, module_name, class_name):
        """
        Return the class named by module_name/class_name if unpickling it
        is allowed.

        Raises ImportError if module_name is neither this module nor
        '__main__' (both are resolved against this module). Raises
        AttributeError if class_name is missing from this module or does
        not name a RemoteException subclass.
        """
        # module_name must be the name of this module. I'm also allowing
        # '__main__' so we can run this module as a script for testing. The
        # class name must be a subclass of RemoteException.
        if module_name not in ('__main__', module.__name__):
            raise ImportError("LimitedUnpickler: can't import '%s'" %
                              module_name)
        class_obj = getattr(module, class_name)
        try:
            allowed = issubclass(class_obj, RemoteException)
        except TypeError:
            # Not a class at all (function, constant, module, ...): reject it
            # the same way as a disallowed class instead of leaking a
            # confusing TypeError from issubclass().
            allowed = False
        if not allowed:
            raise AttributeError("LimitedUnpickler: can't access '%s'" %
                                 class_name)
        return class_obj

    def load(self):
        """Read and return the next object from the underlying file."""
        return self._u.load()


# I suppose I have to write a lengthy justification for why I wrote the
# _SocketMultiplexer/_SocketWrapper micro-framework instead of using
# asyncore.  Here are some good reasons:
#
# 1. asyncore has module-level side-effects that prevent it from being used
# by multiple components that don't know about each other.  An attempt was
# made to fix asyncore by adding a map parameter to its functions and
# methods, but that code was definitely broken as of Python 2.3.
#
# 2. _SocketMultiplexer lets you safely add and remove objects from the
# socket map from multiple threads *without* requiring locks around every
# access to the socket map, by scheduling actions to take place at the
# start of the select loop.
#
# 3. _SocketWrapper has a simpler and more explicit API than
# asyncore.dispatcher. There are fewer methods to override, and no
# __getattr__ nonsense.
#
# 4. asyncore.dispatcher has state attributes that are not applicable to
# many kinds of sockets, e.g. connected, accepting. In contrast, I use the
# type of the _SocketWrapper derived class to tell what state it is in.
# For example, I move a socket from the connecting state to the connected
# state like this:
#
#     def handleWrite(self):
#         self.mux.removeSocket(self)
#         self.mux.addSocket(MyConnectedSocketWrapper(self.sock))
#
# Thus more socket behavior is delegated to the _SocketWrapper derived
# classes, which makes for better encapsulation.
#
# 5. _SocketMultiplexer has a "graceful shutdown" feature. Objects in the
# socket map that have a flush_on_shutdown attribute (and the value is true)
# will be allowed to finish writing before the micro-framework exits.
#
# In any case, this micro-framework doesn't take much code space and it
# works just fine.

class _SocketMultiplexer:
    """
    Wraps call to select.select.
    
    Dispatches read and write events to socket wrappers.
    """

    def __init__(self):
        self.__socket_set = sets.Set()
        self.__stop = False
        self.__poll_thread = None
        self.__poll_thread_lock = threading.Lock()
        self.__signal_sender = None
        self.__actions = []
        self.__actions_lock = threading.Lock()

    def addSocket(self, sock_wrapper):
        sock_wrapper._handleRegister(self)
        self._addAction(self._addSocket, sock_wrapper)
        self.kick()

    def isStopped(self):
        return self.__stop

    def kick(self):
        """Cause the select loop to un-block and re-run itself."""
        if threading.currentThread() is not self.__poll_thread:
            if self.__signal_sender:
                self.__signal_sender.send('!')

    def removeSocket(self, sock_wrapper):
        self._addAction(self._removeSocket, sock_wrapper)
        self.kick()

    def run(self):
        """
        Run the select loop.

        Exits when all non-background sockets are closed.
        (Call stop() to close all sockets.)
        """
        while self.poll(timeout=60.0):
            pass

    def poll(self, timeout=None):
        """
        Run the select loop once and dispatch any messages that are
        complete.

        timeout -- the number of seconds to block waiting for activity,
                   0 to not block at all, and None to block indefinitely.

        Return True if the multiplexer is still running, False if it has
        been stopped.
        """
        self.__poll_thread_lock.acquire()
        try:
            if self.__poll_thread is not None:
                raise LocalException(
                    "Can't poll() from multiple threads simultaneously")
            self.__poll_thread = threading.currentThread()
        finally:
            self.__poll_thread_lock.release()
        try:
            self._init()
            self._loopSetup()
            recv_socks = []
            send_socks = []
            background_count = 0
            for sock_wrapper in self.__socket_set:
                if sock_wrapper.background:
                    background_count += 1
                if sock_wrapper.readable():
                    recv_socks.append(sock_wrapper)
                if sock_wrapper.writeable():
                    send_socks.append(sock_wrapper)
            if self.isStopped():
                recv_socks = []
            debug = getattr(self, 'debug', False)
            if debug:
                print "----------> Before select:"
                print "background_count =", background_count
                print "recv_socks =", recv_socks
                print "send_socks =", send_socks
                print "isStopped =", self.isStopped()
                print "len(self.__socket_set) =", len(self.__socket_set)
            if (background_count == len(self.__socket_set) or
                len(recv_socks) + len(send_socks) == 0):
                # Multiplexer has been stopped, shutdown is complete
                # Clean up and don't call this method on this object again
                self._cleanup()
                self.poll = self._polling_stopped
                return False
            try:
                r, w = select.select(recv_socks, send_socks, [], 
                                     timeout)[:2]
            except (select.error, socket.error):
                e = sys.exc_info()[1]
                if hasattr(errno, 'WSAEINTR'):
                    ok_code = errno.WSAEINTR
                else:
                    ok_code = errno.EINTR
                if (e[0] != ok_code and 
                     self._closeBadSockets(recv_socks, send_socks) == 0):
                    # unrecoverable exception
                    raise
                return True # exception Ok, keep running
            if debug:
                print "----------> After select:"
                print "r =", r
                print "w =", w
            for sock_wrapper in r:
                try:
                    sock_wrapper.handleRead()
                except:
                    traceback.print_exc()
            for sock_wrapper in w:
                try:
                    sock_wrapper.handleWrite()
                except:
                    traceback.print_exc()
        finally:
            self.__poll_thread = None
        return True

    def _polling_stopped(self, timeout=None):
        raise LocalException(
                "Cannot poll() after stop() is complete.")

    def stop(self):
        """Cause the sockets to all close themselves, thus shutting down the
        main loop."""
        self._addAction(self._stop)
        self.kick()

    def _addAction(self, bound_method, *params):
        # Schedule an action to be executed at the beginning of the select
        # loop. These actions are executed this way to avoid modifying the
        # socket map while we are iterating through it.
        self.__actions_lock.acquire()
        try:
            self.__actions.append((bound_method, params))
        finally:
            self.__actions_lock.release()

    def _init(self):
        # One-time initialization of the multiplexer --
        # Find a port number and create a signal socket.
        # Technique from ZODB ZEO/zrpc/trigger.py
        a = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        w = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # set TCP_NODELAY to true to avoid buffering
        w.setsockopt(socket.IPPROTO_TCP, 1, 1)

        # tricky: get a pair of connected sockets
        for port in xrange(10000, 11000):
            address = ('127.0.0.1', port)
            try:
                a.bind(address)
            except socket.error:
                continue
            else:
                break
        else:
            raise RuntimeError, "Can't find available port for signal socket."

        a.listen(1)
        w.setblocking(0)
        try:
            w.connect(address)
        except socket.error:
            pass # assume 'would block' exception or the like
        r, addr = a.accept()
        a.close()
        w.setblocking(1)
        self.__signal_sender = w # un-wrapped socket
        self.addSocket(_SignalSocketWrapper(r))
        self._init = self._init_completed

    def _init_completed(self):
        pass

    def _loopSetup(self):
        # Execute actions that were scheduled by _addAction()
        self.__actions_lock.acquire()
        try:
            actions = self.__actions
            self.__actions = []
        finally:
            self.__actions_lock.release()
        for bound_method, params in actions:
            bound_method(*params)

    def _closeBadSockets(self, recv_socks, send_socks):
        closed_count = 0
        m = sets.Set(recv_socks)
        m.update(send_socks)
        for sock in m:
            try:
                select.select([sock], [sock], [sock], 0.0)
            except (select.error, socket.error):
                self._closeCarefully(sock)
                closed_count += 1
        return closed_count

    def _cleanup(self):
        self.__signal_sender.close()
        self.__signal_sender = None
        debug = getattr(self, 'debug', False)
        sockets = self.__socket_set.copy()
        for sock_wrapper in sockets:
            if debug:
                print "Closing leftover socket:", sock_wrapper
            self._closeCarefully(sock_wrapper)
        if debug:
            print "select loop done"; sys.stdout.flush()

    def _closeCarefully(self, sock_wrapper):
        # Close a socket wrapper without crashing the server, even
        # if there is some kind of error.
        # Prevent the _SocketWrapper from calling back to _SocketMultiplexer
        # to remove itself. If we wanted that to happen, we would have
        # just called sock_wrapper.close().
        self._removeSocket(sock_wrapper)
        try:
            sock_wrapper.close()
        except:
            traceback.print_exc()

    # Action methods

    def _addSocket(self, sock_wrapper):
        self.__socket_set.add(sock_wrapper)

    def _removeSocket(self, sock_wrapper):
        self.__socket_set.discard(sock_wrapper)
        try:
            sock_wrapper._handleUnregister(self)
        except:
            traceback.print_exc()

    def _stop(self):
        self.__stop = True
        still_sending = False
        removable_sockets = []
        for sock_wrapper in self.__socket_set:
            if sock_wrapper.flush_on_shutdown:
                if sock_wrapper.writeable():
                    # Allow sending sockets to complete
                    still_sending = True
                    continue
            removable_sockets.append(sock_wrapper)
        for sock_wrapper in removable_sockets:
            self._closeCarefully(sock_wrapper)
        if still_sending:
            self._addAction(self._stop)


class _SocketWrapper(object):
    """
    Base class that wraps a socket.socket object.

    Derived classes implement these methods:
        readable() - return True iff this object can accept data
        writeable() - return True iff this object has data to write
        handleRead() - read from self.sock/accept connection
        handleWrite() - write to self.sock/complete connection

    The following read-only properties are available:
        mux - the _SocketMultiplexer, or None if mux.addSocket() not called
        sock - the socket object used for reads and writes

    The following methods are provided:
        close() - un-register this object and close the socket
        fileno() - return a file descriptor compatible with select.select()
    """

    # If true, let this socket finish writing before the multiplexer exits.
    flush_on_shutdown = False

    # If true, marks this socket as a background socket.
    # The multiplexer exits when only background sockets are active.
    background = False

    def __init__(self, sock):
        """
        sock: The instance of socket.socket that does the reads and writes.
        """
        self.__mux = None
        self.__sock = sock
        # Cached so fileno() (used by select) still works after close().
        self.__fd = sock.fileno()

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self.__sock)

    def __str__(self):
        # str() must never raise: it is used in debug output, possibly for
        # sockets that are closed and therefore have no name any more.
        sock = self.__sock
        name = None
        if hasattr(sock, 'getsockname'):
            try:
                name = sock.getsockname()
            except socket.error:
                # Fall back to str(sock) below.
                name = None
        if name is None:
            name = str(sock)
        return '%s(%s)' % (self.__class__.__name__, name)

    ################################################################
    # These methods are all you have to worry about implementing.

    def readable(self):
        return True

    def writeable(self):
        return True

    def handleRead(self):
        raise NotImplementedError

    def handleWrite(self):
        raise NotImplementedError

    ################################################################
    # Properties

    def __getmux(self): return self.__mux

    mux = property(__getmux, doc="""
    The instance of _SocketMultiplexer that manages this _SocketWrapper.
    """)

    def __getsock(self): return self.__sock

    sock = property(__getsock, doc="""
    The instance of socket.socket that does the reads and writes.
    """)

    ################################################################
    # Methods that you or the _SocketMultiplexer may call

    def close(self):
        """Ask the multiplexer (if any) to drop this wrapper, then close the
        socket."""
        if self.mux:
            self.mux.removeSocket(self)
        self.sock.close()

    def fileno(self):
        """Return the descriptor captured at construction, for select()."""
        return self.__fd

    ################################################################
    # Implementation details

    def _handleRegister(self, mux):
        # Called by the _SocketMultiplexer when it adds this object to its
        # socket map. Normally this method is not overridden - if it is, the
        # derived class must call the base class. Read the code first.
        if self.__mux is not None and self.__mux is not mux:
            raise LocalException(
                "A _SocketWrapper may only belong to one _SocketMultiplexer.")
        self.__mux = mux

    def _handleUnregister(self, mux):
        # Called by the _SocketMultiplexer when it removes this object from
        # its socket map. Normally this method is not overridden - if it is,
        # the derived class must call the base class. Read the code first.
        if self.__mux is not None and self.__mux is not mux:
            raise LocalException(
                "A _SocketWrapper may only belong to one _SocketMultiplexer.")
        self.__mux = None


class _SignalSocketWrapper(_SocketWrapper):
    """Receiving end of the multiplexer's wake-up socket pair.

    Other threads wake a blocked select() by sending a byte to the other
    end (see _SocketMultiplexer.kick); this wrapper just drains those bytes.
    It is a background socket, so on its own it does not keep the select
    loop running.
    """

    background = True

    # Construction is inherited unchanged from _SocketWrapper.

    def writeable(self):
        # Nothing is ever sent from this end.
        return False

    def readable(self):
        return True

    def handleWrite(self):
        raise LocalException("This method should not have been called.")

    def handleRead(self):
        # The bytes carry no meaning; reading them only clears readiness.
        self.sock.recv(8192)


class _ConnectingSocket(_SocketWrapper):
    """Socket that is in the process of completing a TCP/IP connection.

    When the multiplexer reports the socket writeable, this wrapper is
    removed from the socket map and the wrapper returned by
    connectedWrapperFactory() is added in its place.

    This is a template class: derive from it and define
    connectedWrapperFactory() to return the connected _SocketWrapper
    instance of your choice.
    """

    # Construction is inherited unchanged from _SocketWrapper.

    def writeable(self):
        # Only writeability is watched while the connection completes.
        return True

    def readable(self):
        return False

    def handleWrite(self):
        # Socket has become connected: swap in the connected wrapper.
        mux = self.mux
        mux.removeSocket(self)
        mux.addSocket(self.connectedWrapperFactory(self.sock))

    def connectedWrapperFactory(self, sock):
        """Return an instance of a class derived from _SocketWrapper that
        wraps sock."""
        raise NotImplementedError


def _getTCPListeningSocket(addr_or_sock):
    """
    Return a TCP socket that is ready to accept connections.

    addr_or_sock:
        Either an existing socket (or socket-like object with an accept()
        method), which is returned unchanged and assumed to be listening
        already; or a (host, port) address. For an address, a new IPv4
        socket is created, SO_REUSEADDR is set where supported, and the
        socket is bound and put into listening mode with a backlog of 5.

    Raises socket.error (or whatever bind() raises for a bad address) if the
    new socket cannot be bound or put into listening mode. The new socket
    is closed before the error propagates.
    """
    # Duck typing covers real sockets and the "socket-like objects" that
    # _TCPListener documents accepting.
    if hasattr(addr_or_sock, 'accept'):
        return addr_or_sock
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ready = False
    try:
        # try to re-use a server port if possible
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error:
            pass # best effort only; binding may still succeed
        sock.bind(addr_or_sock)
        sock.listen(5)
        ready = True
    finally:
        if not ready:
            # Don't leak the descriptor when bind()/listen() fails.
            sock.close()
    return sock


class _TCPListener(_SocketWrapper):
    """
    Create a TCP listening socket, and accept incoming connections.

    Override connectedWrapperFactory to wrap accepted sockets in the class
    of your choice. Accepted sockets are set to non-blocking mode by
    default.

    Override checkAddr() to do access control by connecting address.
    """

    def __init__(self, addr_or_sock):
        """
        addr_or_sock:
            A (host, port) address to listen on, or a socket or socket-like
            object ready for listening.
        """
        sock = _getTCPListeningSocket(addr_or_sock)
        sock.setblocking(0)
        super(_TCPListener, self).__init__(sock)

    def readable(self):
        return True

    def writeable(self):
        return False

    def handleRead(self):
        """Accept one pending connection. Called when select() reports the
        listening socket readable."""
        try:
            c, addr = self.sock.accept()
        except socket.error, e:
            # A pending connection can disappear between select() and
            # accept(). The non-blocking socket then reports one of the
            # "try again" errors, which start() already tolerates; just
            # wait for the next connection instead of raising.
            if e[0] not in _expected_socket_read_errors:
                raise
            return
        self._handleAccept(c, addr)

    def handleWrite(self):
        raise LocalException("This method should never be called.")

    def connectedWrapperFactory(self, sock):
        """
        Return an instance of a _SocketWrapper derived class.
        """
        raise NotImplementedError

    def checkAddr(self, addr):
        """
        Return true iff a socket with this address is allowed to connect.
        (Override the base class method to do real access checking.)
        """
        return True

    def start(self, mux):
        """Start accepting connections, and register with the multiplexor.

        A connection that is already pending is accepted immediately.
        """
        try:
            c, addr = self.sock.accept()
        except socket.error, e:
            if not e[0] in _expected_socket_read_errors:
                raise
            c, addr = None, None
        mux.addSocket(self)
        if c is not None:
            # _handleAccept() applies checkAddr() itself, so it is not
            # repeated here.
            self._handleAccept(c, addr)

    def _handleAccept(self, connecting_sock, connecting_addr):
        # Called (normally) by the multiplexor when there is an incoming
        # connection. Rejected peers are closed; accepted ones are made
        # non-blocking, wrapped, and registered with the multiplexor.
        if not self.checkAddr(connecting_addr):
            connecting_sock.close()
            return
        connecting_sock.setblocking(0)
        self.mux.addSocket(self.connectedWrapperFactory(connecting_sock))


class _TCPConnector(_SocketWrapper):
    """
    Create a non-blocking TCP socket and initiate a connection with a
    listening TCP socket.

    Override connectedWrapperFactory() to convert this socket to the wrapped
    socket type of your choice once the connection is completed.
    """

    def __init__(self, addr):
        """addr: the (host, port) address that start() connects to."""
        self.addr = addr
        new_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        new_sock.setblocking(0)
        super(_TCPConnector, self).__init__(new_sock)

    # While connecting, only writeability is watched; it is how select()
    # reports a completed non-blocking connect().

    def writeable(self):
        return True

    def readable(self):
        return False

    def handleRead(self):
        raise LocalException("This method should never be called.")

    def handleWrite(self):
        # NOTE(review): SO_ERROR is not checked here, so a connect() that
        # failed is handed to connectedWrapperFactory() like a successful
        # one. The failure presumably surfaces on that wrapper's first I/O.
        mux = self.mux
        mux.removeSocket(self)
        mux.addSocket(self.connectedWrapperFactory(self.sock))

    def connectedWrapperFactory(self, sock):
        """
        Return an instance of a _SocketWrapper derived class.
        """
        raise NotImplementedError

    def start(self, mux):
        """
        Start the process of connecting to the remote socket, and
        register with the multiplexor.
        """
        try:
            self.sock.connect(self.addr)
        except socket.error, e:
            if e[0] not in _expected_socket_write_errors:
                raise
            finished = False
        else:
            # Connected without waiting (possible e.g. on loopback).
            finished = True
        mux.addSocket(self)
        if finished:
            self.handleWrite()


# All of the following _SocketWrapper classes are designed specifically for
# rpc sockets.


class _RpcListenerSocket(_TCPListener):
    """Socket listening for incoming rpc connections.

    Every accepted connection is wrapped in an _RpcConnectionSocket bound to
    the same node.
    """

    def __init__(self, addr, node):
        """
        addr: (host, port) address or listening socket (see _TCPListener)
        node: the rpc node that accepted connections will belong to
        """
        super(_RpcListenerSocket, self).__init__(addr)
        self.__node = node

    def connectedWrapperFactory(self, sock):
        owner = self.__node
        return _RpcConnectionSocket(sock, owner)


class _RpcConnectingSocket(_TCPConnector):
    """Socket that is making a connection to an rpc Listener.

    Once the connection completes, the multiplexer replaces it with an
    _RpcConnectionSocket bound to the same node.
    """

    def __init__(self, addr, node):
        """
        addr: (host, port) address of the rpc Listener
        node: the rpc node the finished connection will belong to
        """
        super(_RpcConnectingSocket, self).__init__(addr)
        self.__node = node

    def connectedWrapperFactory(self, sock):
        """Return an instance of a class derived from _SocketWrapper."""
        owner = self.__node
        return _RpcConnectionSocket(sock, owner)


class _RpcConnectionSocket(_SocketWrapper):
    """Socket for an rpc Connection.

    Handles read and write events, converts incoming packets into messages,
    and passes them on to the connection object. Buffers and sends outgoing
    packets.

    Each packet on the wire is a 4-byte big-endian signed length followed by
    that many bytes of message data.
    """

    flush_on_shutdown = True

    def __init__(self, sock, node):
        """
        sock: the socket.socket (or similar) object for communication
        node: the Listener node managing this connection
        """
        super(_RpcConnectionSocket, self).__init__(sock)
        self.__node = node
        self.__receive_it = self._receiver()
        self.__send_it = self._sender()
        # Messages queued by sendMessage() that have not been started yet.
        self.__out_q = []
        # Unsent remainder of the packet currently being written ('' when
        # no packet is in progress).
        self.__out_packet = ''
        self.connection = Connection(self, node)

    def close(self):
        self.connection._cancelAllCalls(None)
        super(_RpcConnectionSocket, self).close()

    def readable(self):
        return True

    def writeable(self):
        # A partially sent packet counts as pending output too. Otherwise a
        # short send() of the last queued message would leave the rest of
        # that packet stranded: it would never be selected for writing and
        # would not be flushed on shutdown.
        return bool(self.__out_q) or bool(self.__out_packet)

    def handleRead(self):
        try:
            self.__receive_it.next()
            return
        except (StopIteration, EOFError):
            self.connection._handleDisconnect()
        except socket.error, e:
            self.connection._handleError(e)
        self.close()

    def handleWrite(self):
        try:
            self.__send_it.next()
            return
        except StopIteration:
            self.connection._handleDisconnect()
        except socket.error, e:
            self.connection._handleError(e)
        self.close()

    def _readBytes(self, n):
        # Read up to n bytes. Return None if no data is available right now;
        # raise EOFError when the peer has closed the connection.
        try:
            data = self.sock.recv(n)
        except socket.error, e:
            if e[0] in _expected_socket_read_errors:
                return None
            else:
                raise
        else:
            if data:
                return data
            else:
                raise EOFError

    def _receiver(self):
        # Generator-iterator that reads the 4-byte packet length, then reads
        # the packet, and hands each complete message to the connection.
        # Yields None whenever the socket has no more data for now.
        while True:
            count_str = ''
            while len(count_str) < 4:
                data = self._readBytes(4 - len(count_str))
                if data is None:
                    yield None
                    continue
                count_str += data
            count = struct.unpack(">i", count_str)[0]
            if count < 0:
                # bad packet, close the connection
                return
            # Collect chunks and join once; repeated string += would be
            # quadratic for large messages.
            chunks = []
            while count:
                data = self._readBytes(count)
                if data is None:
                    yield None
                    continue
                chunks.append(data)
                count -= len(data)
            self.connection._handleMessage(''.join(chunks))

    def _sender(self):
        # Generator-iterator that writes the 4-byte packet length, then
        # writes the packet contents. The unsent part of the current packet
        # lives in self.__out_packet so that writeable() can see it. Yields
        # None when there is nothing left to send, or when the socket cannot
        # accept more data right now.
        while True:
            if not self.__out_packet:
                try:
                    message = self.__out_q.pop(0)
                except IndexError:
                    message = None
                if message is None:
                    yield None
                    continue
                self.__out_packet = struct.pack(">i", len(message)) + message
            try:
                n = self.sock.send(self.__out_packet)
            except socket.error, e:
                if e[0] in _expected_socket_write_errors:
                    yield None
                else:
                    raise
            else:
                self.__out_packet = self.__out_packet[n:]

    ###########################################
    # Interface for Connection to send messages
    # May be called by any thread

    def sendMessage(self, message):
        """
        Send a string message to the remote end of the connected socket.
        """
        self.__out_q.append(message)
        mux = self.mux
        if mux:
            # Wake the select loop so it starts watching for writeability.
            mux.kick()


class _Node(object):
    """
    Base class for Listener and Connector.

    Handles message events. Converts between Python objects and string
    messages using the supplied dumper and loader factories.
    """

    def __init__(self, addr, handlerFactory, 
                 dumperFactory=cPickle.Pickler,
                 loaderFactory=LimitedUnpickler):
        self.addr = addr
        self.handlerFactory = handlerFactory
        self.dumperFactory = dumperFactory
        self.loaderFactory = loaderFactory
        self._mux = _SocketMultiplexer()
        self._disconnectHandler = self._defaultDisconnectHandler
        self._errorHandler = self._defaultErrorHandler

    def __repr__(self):
        return '%s(%s, %s)' % (self.__class__.__name__, self.addr,
                               self.handlerFactory)

    def poll(self, timeout=None):
        """
        Poll the RPC node. Dispatch all messages that are ready.

        timeout -- the number of seconds to block waiting for activity,
                   0 to not block at all, and None to block indefinitely.

        Return True if the node is still running, False if it has
        been stopped.
        """
        return self._mux.poll(timeout=timeout)

    def run(self):
        """
        Run the RPC node. Exit when all connections are closed.
        """
        self._mux.run()

    def stop(self):
        """
        Close all current connections and stop accepting new requests.
        Results that are already being transmitted will be allowed to
        complete.
        """
        self._mux.stop()

    def onDisconnectCall(self, disconnectHandler):
        """
        Call disconnectedHandler(connection) when the remote node closes
        the connection normally.

        Return the previous error handler.
        """
        prev_handler = self._disconnectHandler
        self._disconnectHandler = disconnectHandler
        return prev_handler

    def onErrorCall(self, errorHandler):
        """
        Call errorHandler(connection, error) on communication errors, such
        as failure to connect.

        Return the previous error handler.
        """
        prev_handler = self._errorHandler
        self._errorHandler = errorHandler
        return prev_handler

    def dump(self, obj):
        """Convert an object to transport format"""
        sb = cStringIO.StringIO()
        self.dumperFactory(sb).dump(obj)
        return sb.getvalue()

    def load(self, s):
        """Convert from transport format to Python object"""
        return self.loaderFactory(cStringIO.StringIO(s)).load()

    def _handleDisconnect(self, connection):
        if self._disconnectHandler:
            self._disconnectHandler(connection)

    def _handleError(self, connection, error):
        if self._errorHandler:
            self._errorHandler(connection, error)

    def _defaultDisconnectHandler(self, connection):
        if getattr(self._mux, 'debug', False):
            print connection, "disconnected"
        connection.close()

    def _defaultErrorHandler(self, connection, error):
        sys.stderr.write("Error on " + repr(connection) + ": " +
                         str(error) + '\n')
        connection.close()


class InProcNode(_Node):
    """
    This is an RPC node for communicating between threads in the same
    process with the same semantics as if they were in different processes.
    Create an instance of this class and run it in the "server" thread.
    Then pass that instance as a parameter to connect() or
    getInProcConnection() to send messages to that node from a "client"
    thread.
    """

    STOP_MARKER = object()

    def __init__(self, handlerFactory):
        self.handlerFactory = handlerFactory
        self._disconnectHandler = self._defaultDisconnectHandler
        self._errorHandler = self._defaultErrorHandler
        self._incoming_q = Queue.Queue()
        self._connection_map = {} # {remote_socket: local_connection}

    def _connect(self, socket_wrapper, connection):
        self._connection_map[socket_wrapper] = connection

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self.handlerFactory)

    def __str__(self):
        return 'Local Node'

    def _postMessage(self, sock_wrapper, message):
        self._incoming_q.put((sock_wrapper, message))


    def poll(self, timeout=None):
        """
        Poll the RPC node. Dispatch all messages that are ready.

        timeout -- the number of seconds to block waiting for activity,
                   0 to not block at all, and None to block indefinitely.

        Return True if the node is still running, False if it has
        been stopped.
        """
        block = True
        if not timeout and timeout is not None:
            block = False
        try:
            msg = self._incoming_q.get(block, timeout)
        except Queue.Empty:
            return True
        if msg is self.STOP_MARKER:
            return False
        sock_wrapper, message = msg
        connection = self._connection_map.get(sock_wrapper)
        if connection:
            if message is None:
                del self._connection_map[sock_wrapper]
                self._handleDisconnect(connection)
            else:
                connection._handleMessage(message)
        return True

    def run(self):
        """
        Run the RPC node.
        """
        while self.poll(timeout=60.0):
            pass

    def stop(self):
        """
        Close all current connections and stop accepting new requests.
        Results that are already being transmitted will be allowed to
        complete.
        """
        self._incoming_q.put(self.STOP_MARKER)

    def onDisconnectCall(self, disconnectHandler):
        """
        Call disconnectedHandler(connection) when the remote node closes
        the connection normally.

        Return the previous error handler.
        """
        prev_handler = self._disconnectHandler
        self._disconnectHandler = disconnectHandler
        return prev_handler

    def onErrorCall(self, errorHandler):
        """
        Call errorHandler(connection, error) on communication errors, such
        as failure to connect.

        Return the previous error handler.
        """
        prev_handler = self._errorHandler
        self._errorHandler = errorHandler
        return prev_handler

    def dump(self, obj):
        """Convert an object to transport format"""
        return copy.copy(obj)

    def load(self, s):
        """Convert from transport format to Python object"""
        return s

    def _handleDisconnect(self, connection):
        if self._disconnectHandler:
            self._disconnectHandler(connection)

    def _handleError(self, connection, error):
        if self._errorHandler:
            self._errorHandler(connection, error)

    def _defaultDisconnectHandler(self, connection):
        if getattr(self, 'debug', False):
            print connection, "disconnected"
        connection.close()

    def _defaultErrorHandler(self, connection, error):
        sys.stderr.write("Error on " + repr(connection) + ": " +
                         str(error) + '\n')
        connection.close()


class _InProcClient(InProcNode):
    """
    An InProcNode that stops itself as soon as its last remaining
    connection has disconnected. getInProcConnection() creates one for each
    in-process client, which is what connect() uses for in-proc nodes.
    """

    def _handleDisconnect(self, connection):
        super(_InProcClient, self)._handleDisconnect(connection)
        if self._connection_map:
            # Other connections are still alive; keep running.
            return
        # Last one out turns off the lights.
        self.stop()


class Listener(_Node):
    """
    RPC node that accepts incoming RPC connections.
    """

    def __init__(self, addr, handlerFactory,
                 dumperFactory=cPickle.Pickler,
                 loaderFactory=LimitedUnpickler):
        """
        addr:
            TCP address (host, port) to listen on for incoming connections,
            or a socket object that is already set up for listening.
        handlerFactory:
            Called with one argument, the new Connection object, whenever a
            connection is accepted. Must return the object that handles
            method calls coming from the remote RPC node.
        dumperFactory:
            Called with an output file object to create a fresh object
            dumper (marshaller) for every message sent.
        loaderFactory:
            Called with an input file object to create a fresh object
            loader (unmarshaller) for every message received.
        """
        super(Listener, self).__init__(addr, handlerFactory,
                                       dumperFactory, loaderFactory)
        # Hand the listening socket over to this node's multiplexer.
        listening = _RpcListenerSocket(addr, self)
        listening.start(self._mux)


class Connector(_Node):
    """
    RPC node that opens a connection to an RPC listener.
    """

    def __init__(self, addr, handlerFactory,
                 dumperFactory=cPickle.Pickler,
                 loaderFactory=LimitedUnpickler):
        """
        addr:
            TCP address (host, port) of the RPC node to connect to.
        handlerFactory:
            Called with one argument, the new Connection object, once the
            connection is complete. Must return the object that handles
            method calls coming from the remote RPC node.
        dumperFactory:
            Called with an output file object to create a fresh object
            dumper (marshaller) for every message sent.
        loaderFactory:
            Called with an input file object to create a fresh object
            loader (unmarshaller) for every message received.
        """
        super(Connector, self).__init__(addr, handlerFactory,
                                        dumperFactory, loaderFactory)
        # Hand the connecting socket over to this node's multiplexer.
        connecting = _RpcConnectingSocket(addr, self)
        connecting.start(self._mux)


class Connection(object):
    """
    Connection with the remote RPC node.

    Sends method calls to the remote object, routes return values to the
    appropriate callback, and handles calls originating from the remote
    object.

    Every message is a (method, params, seqnum) tuple converted with
    node.dump()/node.load(). A request carries a non-empty method name; a
    reply carries method None, the result payload in params, and the
    seqnum of the call it answers. seqnum None means "no reply wanted".
    """

    def __init__(self, sock_wrapper, node):
        """
        Called when a connection is established with the remote rpc node
        (by the socket wrapper, or by getInProcConnection()).

        sock_wrapper -- transport; must provide sendMessage(message) and
                        close()
        node -- the owning RPC node; provides handlerFactory, dump(),
                load(), _handleDisconnect() and _handleError()
        """
        self.__sock_wrapper = sock_wrapper
        self.__node = node
        self.__seq_num = 0
        self.__seq_num_lock = threading.Lock()
        self.__seq_num_callback_map = {}
        self.__closed = False
        # Always define the handler: without a factory the remote node
        # cannot call this side, and incoming calls must be answered with
        # MethodNotFound rather than crash with AttributeError.
        self.__handler = None
        # do this last, don't know what handlerFactory will do with self
        factory = node.handlerFactory
        if factory:
            self.__handler = factory(self)

    def __repr__(self):
        return 'Connection(%s, %s)' % (repr(self.__sock_wrapper), 
                                       repr(self.__node))

    def __str__(self):
        return 'Connection(%s, %s)' % (self.__sock_wrapper, self.__node)

    def __getnode(self): return self.__node

    node = property(__getnode, doc="""
    The RPC node (Listener or Connector) that manages this Connection.
    Call node.stop() to shut down all connections. Use with care.
    """)

    def close(self):
        """Close this connection with the remote RPC node.

        Pending calls are cancelled (their callbacks receive a Disconnected
        result) and later callMethod() calls raise Disconnected.
        """
        self.__closed = True
        self._cancelAllCalls(None)
        self.__sock_wrapper.close()

    def setHandler(self, handler):
        """Replace the object that handles calls from the remote node."""
        self.__handler = handler

    def getHandler(self):
        """Return the object handling calls from the remote node, or None
        if there is none."""
        return self.__handler

    def getBlockingProxy(self, timeout=None):
        """Return a BlockingProxy that waits up to timeout seconds
        (None = forever) for each result."""
        return BlockingProxy(self, timeout)

    def getProxy(self, callback=None):
        """Return a Proxy that delivers each Result to callback (or asks
        for no reply when callback is None)."""
        return Proxy(self, callback)

    def callMethod(self, method, params, callback):
        """
        Call a remote procedure. A callback function will be called when the
        remote procedure call is complete.

        method: The name of the method on the remote server to call.
        params: The parameters to pass to the remote method.
        callback:
            A callable object that is called with a single parameter, an
            instance of Result, when the rpc call is complete. Call
            getResult() on the result object to get the return value (or
            raise the returned RemoteException.) Pass None for this
            parameter when you don't want the server to generate a reply
            message.

        Raise ValueError if callback is given but is not callable. Any
        exception raised while marshalling the call propagates, and in that
        case no callback is left registered.
        """
        seq_num = None
        if callback:
            if not callable(callback):
                raise ValueError("Callback must be a callable object.")
            self.__seq_num_lock.acquire()
            try:
                self.__seq_num += 1
                seq_num = self.__seq_num
            finally:
                self.__seq_num_lock.release()
        # Marshal before registering the callback, so that a dump() failure
        # (e.g. an unpicklable parameter) cannot leave a stale entry behind.
        message = self.__node.dump((method, params, seq_num))
        if seq_num is not None:
            self.__seq_num_lock.acquire()
            try:
                self.__seq_num_callback_map[seq_num] = callback
            finally:
                self.__seq_num_lock.release()
        self.__sock_wrapper.sendMessage(message)

    def callMethodAndBlock(self, method, params, timeout=None):
        """
        Call a remote procedure and wait for the return value.

        method: The name of the method on the remote server to call.
        params: The parameters to pass to the remote method.
        timeout:
            The time, in seconds, to wait for the remote process to
            complete the call, or None to wait indefinitely.

        Return the result of the remote procedure call.
        Raise RemoteException if there was an error executing the remote
        procedure. Raise a TimeoutException if the result doesn't come back
        in time.
        """
        waiter = _MethodWaiter(timeout)
        self.callMethod(method, params, waiter.callComplete)
        return waiter.getResult()

    def _handleMessage(self, message):
        # Called by the transport when a new message arrives.
        message_tuple = self.__node.load(message)
        try:
            method, params, seqnum = message_tuple
        except (ValueError, TypeError):
            # Not a sequence (TypeError) or not three items (ValueError).
            # Can't return an error message to the caller without seqnum.
            traceback.print_exc()
            return
        if method:
            # Incoming request. Private names are refused before any
            # attribute lookup happens on the handler.
            if method[0] == '_':
                m = None
            else:
                m = getattr(self.__handler, method, None)
            if not m:
                response = MethodNotFound(method, self.__handler)
            else:
                try:
                    response = (m(*params),)
                except RemoteException:
                    response = sys.exc_info()[1]
                except:
                    # report exception back from server
                    response = ExceptionInMethod(method, self.__handler)
            if seqnum:
                self.__sock_wrapper.sendMessage(self.__node.dump(
                                                 (None, response, seqnum)))
            elif isinstance(response, RemoteException):
                sys.stderr.write("Unreported RemoteException in method:\n")
                sys.stderr.write("%s\n" % str(response))
        else:
            # Reply to one of our calls. Take the callback under the same
            # lock _cancelAllCalls() uses, so a callback is never run both
            # with the real result and with a Disconnected result.
            self.__seq_num_lock.acquire()
            try:
                callback = self.__seq_num_callback_map.pop(seqnum, None)
            finally:
                self.__seq_num_lock.release()
            if callback:
                callback(Result(params))

    def _handleDisconnect(self):
        # Called by the transport when the remote end closes the socket
        self._cancelAllCalls(None)
        self.__node._handleDisconnect(self)

    def _handleError(self, error):
        # Called by the transport when there is an error on the socket;
        # errors after close() are not reported to the node.
        if not self.__closed:
            self._cancelAllCalls(error)
            self.__node._handleError(self, error)
        else:
            sys.stdout.flush()

    def _cancelAllCalls(self, error):
        # Cancel all pending method calls with a Disconnected exception,
        # taking care to avoid race conditions. Failing to cancel a method
        # call could result in a call blocking forever when the socket is
        # closed.
        # Rebinding these names on the instance shadows the methods: later
        # cancels become no-ops and new calls raise Disconnected.
        self._cancelAllCalls = self._callsAlreadyCanceled
        self.callMethod = self._noMoreMethodCalls
        while True:
            self.__seq_num_lock.acquire()
            try:
                callbacks = list(self.__seq_num_callback_map.values())
                self.__seq_num_callback_map = {}
            finally:
                self.__seq_num_lock.release()
            for callback in callbacks:
                try:
                    callback(Result(Disconnected(self, error)))
                except:
                    traceback.print_exc()
            # A thread that was already inside callMethod() may still
            # register a callback; pause briefly and sweep again if so.
            time.sleep(0.001)
            if not self.__seq_num_callback_map:
                break

    def _callsAlreadyCanceled(self, error):
        pass # calling _cancelAllCalls only works the first time

    def _noMoreMethodCalls(self, method, params, callback):
        raise Disconnected(self, "This connection has been shut down.")


def connect(addr, 
            handlerFactory=None,
            timeout=None,
            disconnectHandler=None,
            errorHandler=None,
            dumperFactory=cPickle.Pickler,
            loaderFactory=LimitedUnpickler):
    """
    Connect to an RPC node and return the Connection once it is complete.

    addr:
        TCP address (host, port) of the destination RPC node, or an
        InProcNode to connect to within this process.
    handlerFactory:
        Called with one argument, the Connection object, when the
        connection is complete. Must return the object that handles method
        calls from the remote RPC node. If None, the server cannot call
        back to the client.
    timeout:
        Seconds to wait for the connection to complete, or None to wait
        indefinitely.
    disconnectHandler:
        disconnectHandler(connection) is called when the other side
        closes the connection normally.
    errorHandler:
        errorHandler(connection, error) is called on connection errors.
    dumperFactory:
        Called with an output file object to create a fresh object dumper
        (marshaller) for every message sent.
    loaderFactory:
        Called with an input file object to create a fresh object loader
        (unmarshaller) for every message received.

    A Connector node is started on a background (daemon) thread, which
    keeps running and processing messages after this function returns.

    Raise TimeoutException if the connection is not completed before
    the time expires.
    """
    if isinstance(addr, InProcNode):
        # No sockets involved for an in-process peer.
        return _connectInProc(addr, handlerFactory)

    waiting_factory = _WaitingFactory(handlerFactory)
    connector = Connector(addr, waiting_factory, dumperFactory,
                          loaderFactory)
    for handler, install in ((disconnectHandler, connector.onDisconnectCall),
                             (errorHandler, connector.onErrorCall)):
        if handler:
            install(handler)

    worker = threading.Thread(target=connector.run)
    worker.setDaemon(True) # Don't blame me, I didn't name it
    worker.start()

    # NOTE(review): the event is only set when _WaitingFactory is invoked,
    # i.e. when a Connection has been created. A failed connection attempt
    # is reported to the error handler instead, so with timeout=None this
    # wait appears to block forever -- confirm and consider failing fast.
    waiting_factory.event.wait(timeout)
    if waiting_factory.event.isSet():
        return waiting_factory.connection
    connector.stop()
    raise TimeoutException


def _connectInProc(server, handlerFactory=None):
    """
    Connect to another "in-proc" node.

    Builds the client node with getInProcConnection(), runs it on a daemon
    thread, and returns the client-side Connection.
    """
    client_node, connection = getInProcConnection(server, handlerFactory)
    poller = threading.Thread(target=client_node.run)
    poller.setDaemon(True) # Don't blame me, I didn't name it
    poller.start()
    return connection


def getInProcConnection(server, handlerFactory=None):
    """
    Wire up an in-process connection between a new client node and server.

    server -- The other InProcNode to which we want to connect
    handlerFactory -- callable to return client message handler

    Return (client, connection)
    client -- A new InProcNode that the client can poll()
    connection -- Connection for calling the server node
    """
    client = _InProcClient(handlerFactory)
    # Each wrapper posts into the node it was built for: the client's
    # Connection writes through to_server and the server's Connection
    # writes through to_client.
    to_client = _InProcSocketWrapper(client)
    to_server = _InProcSocketWrapper(server)
    client_connection = Connection(to_server, client)
    server_connection = Connection(to_client, server)
    # Each node dispatches messages arriving through "its" wrapper to the
    # Connection living on that node.
    client._connect(to_client, client_connection)
    server._connect(to_server, server_connection)
    return client, client_connection


class _WaitingFactory:
    # Handler-factory wrapper used by connect(). When the Connection is
    # created it records it, builds the real handler (if any), and then
    # sets `event` so the thread blocked in connect() can return.
    def __init__(self, realFactory):
        self.realFactory = realFactory
        self.event = threading.Event()
        self.connection = None

    def __call__(self, connection):
        self.connection = connection
        try:
            if self.realFactory:
                return self.realFactory(connection)
            return None
        finally:
            # Signal only after the user's factory has run, so connect()
            # does not return the Connection while the handler is still
            # being built. Doing it in `finally` still releases the waiter
            # if the factory raises.
            self.event.set()


class _Method:
    # Callable stand-in for one named remote method: calling it forwards
    # the method name plus the positional arguments (as a tuple) to `send`
    # and returns whatever `send` returns.
    # (adapted from xmlrpclib._Method, minus the security bug)
    def __init__(self, send, name):
        self.__send = send
        self.__name = name

    def __call__(self, *args):
        forward, method_name = self.__send, self.__name
        return forward(method_name, args)


class _MethodWaiter:
    # Hand-off point between the thread making a blocking call and the
    # thread that delivers the call's Result through callComplete().
    def __init__(self, timeout=None):
        self.__timeout = timeout
        self.__event = threading.Event()
        self.__result = None

    def callComplete(self, result):
        # Store the result before signalling the waiting thread.
        self.__result = result
        self.__event.set()

    def getResult(self):
        # Wait up to the timeout, then unwrap the Result (which may raise).
        done = self.__event
        done.wait(self.__timeout)
        if not done.isSet():
            raise TimeoutException()
        try:
            outcome = self.__result
            return outcome.getResult()
        finally:
            # Reset so the same waiter can serve the next call.
            done.clear()
            self.__result = None

    def getTimeout(self):
        return self.__timeout


class Proxy:
    """
    Make calls to an rpc server.  The RPC framework will call a
    callback function when the result comes back from the server.

    Any public attribute is a callable that issues the remote call of the
    same name, e.g. proxy.add(3, 4). Attributes starting with '_' are not
    forwarded (raising AttributeError); the remote side refuses such
    names anyway.
    """

    def __init__(self, connection, callback):
        """
        connection:
            The Connection that will transmit the procedure call and receive
            the return value.
        callback:
            A callable object that is called with a single parameter, an
            instance of Result, when the RPC call is complete. Call
            getResult() on the result object to get the return value (or to
            raise the returned RemoteException.)

            If this parameter is None then no value will be returned.
        """
        self._connection = connection
        self._callback = callback

    def _callMethod(self, name, params):
        self._connection.callMethod(name, params, self._callback)

    def __getattr__(self, name):
        # technique borrowed from xmlrpclib
        # On this old-style class Python looks up special methods
        # (__nonzero__, __repr__, __hash__, __deepcopy__, ...) through
        # __getattr__; without this guard they would become remote calls.
        if name.startswith('_'):
            raise AttributeError(name)
        return _Method(self._callMethod, name)


class BlockingProxy:
    """
    Make calls to an rpc server and wait for the result.

    Any public attribute is a callable that performs the remote call of the
    same name and returns its result, e.g. proxy.add(3, 4). Attributes
    starting with '_' are not forwarded (raising AttributeError); the
    remote side refuses such names anyway.
    """

    def __init__(self, connection, timeout=None):
        """
        connection:
            The Connection that will transmit the procedure call
            and receive the return value.
        timeout:
            The time, in seconds, to wait for the remote process to
            complete the call, or None to wait indefinitely.
        """
        self._connection = connection
        self._waiter = _MethodWaiter(timeout)

    def _callMethod(self, name, params):
        self._connection.callMethod(name, params, self._waiter.callComplete)
        try:
            return self._waiter.getResult()
        except TimeoutException:
            # replace the waiter object in case the reply comes later
            # (don't want to get the return values confused)
            self._waiter = _MethodWaiter(self._waiter.getTimeout())
            raise

    def __getattr__(self, name):
        # technique borrowed from xmlrpclib
        # On this old-style class Python looks up special methods
        # (__nonzero__, __repr__, __hash__, __deepcopy__, ...) through
        # __getattr__; without this guard e.g. bool(proxy) would perform a
        # blocking remote call.
        if name.startswith('_'):
            raise AttributeError(name)
        return _Method(self._callMethod, name)


class Result:
    """
    Holds the outcome of an rpc call until the caller retrieves it.
    Normally only the Connection creates instances of this class.
    """

    def __init__(self, result_params):
        """
        result_params:
            The result item straight from the RPC message tuple: either a
            1-tuple holding the return value, or an RpcException instance.
        """
        self.__result = result_params

    def getResult(self):
        """
        Return the value the remote procedure returned (possibly None), or
        raise an exception if the call failed.

        Failures on the server side raise a RemoteException; anything
        else raises a LocalException.
        """
        payload = self.__result
        if type(payload) is tuple:
            # Normal return: the value is wrapped in a 1-tuple.
            return payload[0]
        if isinstance(payload, RpcException):
            raise payload
        raise LocalException("Return value in wrong format", payload)


class _InProcSocketWrapper:
    # Socket stand-in for in-process connections. Every outgoing message,
    # and the close notification (sent as None), is posted to the node
    # this wrapper was created for, tagged with the wrapper itself so the
    # node can find the matching Connection.

    def __init__(self, node):
        self.__node = node

    def __post(self, payload):
        self.__node._postMessage(self, payload)

    def sendMessage(self, message):
        self.__post(message)

    def close(self):
        # None is the end-of-connection marker.
        self.__post(None)

    def __repr__(self):
        return '_InProcSocketWrapper@%s' % (id(self),)

    def __str__(self):
        return 'InProc Socket'


def lookupInetHost(addr):
    """
    Look up the IP address of an internet host.

    addr -- hostname or (hostname, port) tuple

    Return one of the following:
        IP address string if addr was a string, or
        (ip_address, port) if addr was (host, port), or
        exception object if there was an error
    """
    is_pair = type(addr) is tuple
    try:
        if is_pair:
            host, port = addr
        else:
            host = addr
        ip_addr = socket.gethostbyname(host)
    except Exception:
        # normally a socket.error; handed back rather than raised
        return sys.exc_info()[1]
    if is_pair:
        return (ip_addr, port)
    return ip_addr


def lookupInetHostInBackground(addr, callback):
    """
    Look up the IP address of an internet host on a background thread.

    addr -- hostname or (hostname, port) tuple
    callback -- function to call when lookup is complete

    The callback runs on the new (daemon) thread, as callback(result),
    where result is:
        the ip address string if addr was a string, or
        (ip_address, port) if addr was (host, port), or
        the exception object if there was an error

    Return the new thread object; join() it to wait for the lookup to
    finish, if desired.
    """
    worker = threading.Thread(target=lambda: callback(lookupInetHost(addr)))
    worker.setDaemon(True) # Don't blame me, I didn't name it
    worker.start()
    return worker


######################################################################
# Tests and test infrastructure
# 


class _Tester(object):
    
    def __init__(self):
        self.listener_start_event = threading.Event()

    def runSocketMultiplexer(self, addr):
        class TestListener(_TCPListener):
            def __init__(self, addr, tester):
                super(TestListener, self).__init__(addr)
                self.tester = tester
            def connectedWrapperFactory(self, sock):
                self.tester.listener_start_event.set()
                return EchoConnection(sock)
        class EchoConnection(_SocketWrapper):
            flush_on_shutdown = True
            def __init__(self, sock):
                super(EchoConnection, self).__init__(sock)
                self.__out = []
            def readable(self):
                return True
            def writeable(self):
                return len(self.__out)
            def handleRead(self):
                data = self.sock.recv(8192)
                if not data:
                    self.close()
                    return
                if 'stop' in data:
                    self.__out.append('goodbye')
                    self.mux.stop()
                    return
                self.__out.append(data)
            def handleWrite(self):
                data = self.__out.pop(0)
                self.sock.send(data)
        mux = _SocketMultiplexer()
        #mux.debug = True
        TestListener(addr, self).start(mux)
        mux.run()
        #print "Echo server done."
        sys.stdout.flush()

    def testSocketMultiplexer(self, addr=('127.0.0.1', 9114)):
        self.listener_start_event.clear()
        t = threading.Thread(target=self.runSocketMultiplexer, args=(addr,))
        t.start()
        self.listener_start_event.wait(2.0)
        class TestConnector(_TCPConnector):
            def __init__(self, addr):
                super(TestConnector, self).__init__(addr)
            def connectedWrapperFactory(self, sock):
                wrapper = EchoClient(sock)
                t = threading.Thread(target=wrapper.runEchoClient)
                t.start()
                return wrapper
        class EchoClient(_SocketWrapper):
            def __init__(self, sock):
                super(EchoClient, self).__init__(sock)
                self.__out = []
                self.__out_lock = threading.Lock()
                self.__event = threading.Event()
                self.__received = None
            def readable(self):
                return True
            def writeable(self):
                return len(self.__out)
            def handleRead(self):
                data = self.sock.recv(8192)
                print "Received", repr(data)
                sys.stdout.flush()
                if not data:
                    print "Remote end closed the socket."
                    sys.stdout.flush()
                    self.close()
                    return
                self.__received = data
                self.__event.set()
            def handleWrite(self):
                self.__out_lock.acquire()
                try:
                    data = self.__out.pop(0)
                finally:
                    self.__out_lock.release()
                print "Sending", repr(data)
                sys.stdout.flush()
                self.sock.send(data)
            # extension to interface
            def send(self, data):
                self.__event.clear()
                self.__out_lock.acquire()
                try:
                    self.__out.append(data)
                finally:
                    self.__out_lock.release()
                self.mux.kick()
            def runEchoClient(self):
                import time
                msg = 'Are we there yet?'
                self.send(msg)
                self.__event.wait(1.0)
                assert msg == self.__received
                self.send('stop')
                # That last send should cause the server to stop.
                # The server should send 'goodbye' and close the socket.
                self.__event.wait(1.0)
                assert 'goodbye' == self.__received
        mux = _SocketMultiplexer()
        #mux.debug = True
        TestConnector(addr).start(mux)
        mux.run()
        print "Echo client test complete."

    def testLookupInetHost(self):
        assert lookupInetHost('localhost') == '127.0.0.1'
        assert lookupInetHost(('localhost', 999)) == ('127.0.0.1', 999)
        assert isinstance(lookupInetHost('zzzzz.abababab.com'), socket.error)

    def testLookupInetHostInBackground(self):
        def _callback1(result):
            assert result == '127.0.0.1'
            print "Ok 1"
        t = lookupInetHostInBackground('localhost', _callback1)
        t.join()
        def _callback2(result):
            assert result == ('127.0.0.1', 999)
            print "Ok 2"
        t = lookupInetHostInBackground(('localhost', 999), _callback2)
        t.join()
        def _callback3(result):
            assert isinstance(result, socket.error)
            print "Ok 3"
        t = lookupInetHostInBackground('zzzzz.abababab.com', _callback3)
        t.join()

    def testRPC(self, addr=('127.0.0.1', 9096)):
        import threading
        # client message handler
        class Client:
            def __init__(self, connection):
                self.connection = connection
                t = threading.Thread(target=self._runTest)
                t.start()
            def callMe(self):
                print 'server called client'
            def _runTest(self):
                proxy = self.connection.getBlockingProxy(None)
                print proxy.add(3, 4)
                proxy.callMeBack()
                print proxy.stop()
        # start server
        self._runListener(addr)
        # start client
        connector = Connector(addr, Client)
        #connector._mux.debug = True
        connector.run()
        #print connector, 'done'
        # look at Client._runTest to see the rest of the test.

    def _runListener(self, addr):
        self.listener = Listener(addr, self.Server)
        #listener._mux.debug = True
        t = threading.Thread(target=self.listener.run)
        t.setDaemon(True)
        t.start()
        time.sleep(0.5) # give listener time to start

    # server message handler
    class Server:
        def __init__(self, connection):
            self.connection = connection
        def add(self, a, b):
            return a + b
        def callMeBack(self):
            self.connection.callMethod('callMe', (), None)
        def stop(self):
            self.connection.node.stop()
            return 'goodbye'

    def testExceptionHandling(self, addr=None):
        _TestSuite2(9937).run()

    def testInProcMessaging(self, addr=None):
        _TestSuite3().run()

    def main(self):
        if len(sys.argv) <= 1:
            # run all tests
            print
            print "Running doctests"
            print
            import doctest
            doctest.testmod(module, verbose=False)
            print
            print "Running regression tests"
            print
            pass_count = 0
            for method_name in dir(self):
                if method_name.startswith('test'):
                    print method_name
                    getattr(self, method_name)()
                    pass_count += 1
            print pass_count, "tests pass"
        else:
            # run a single test
            action = sys.argv[1]
            if len(sys.argv) >= 3:
                port = int(sys.argv[2])
            else:
                port = 9000
            addr = ('127.0.0.1', port)
            if hasattr(self, action):
                test_method = getattr(self, action)
                test_method(addr)
            else:
                print "Unknown action:", action


class _TestSuite2:
    """Regression tests focused on errors, exceptions and disconnects.

    Each test runs a Connector in a background thread against a Listener
    on ('127.0.0.1', port).  The Connector's callbacks set
    threading.Event objects, which the tests use to observe what happened.
    """

    class Server:
        # Server-side handler; Listener is given this class as its factory.
        def __init__(self, connection):
            self.connection = connection
        def add(self, a, b):
            return a + b
        def close(self):
            self.connection.close()

    def __init__(self, port):
        self.port = port
        self.connected_event = threading.Event()     # set by _clientFactory
        self.error_event = threading.Event()         # set by _onError
        self.disconnected_event = threading.Event()  # set by _onDisconnect
        self.connector_done = threading.Event()      # set after connector.run()
        self.error_obj = None    # error object last passed to _onError
        self.connection = None   # client-side connection from _clientFactory
        self.listener = None
        self.connector = None

    def _clientFactory(self, connection):
        # Connector calls this with the new client-side connection.
        self.connection = connection
        # Wake up tests blocked in connected_event.wait(); they read
        # self.connection next, so it must be assigned before this.
        self.connected_event.set()
        return None # server doesn't call client

    def _onError(self, connection, error):
        print connection, "had error", error
        assert connection is self.connection
        self.error_obj = error
        self.error_event.set()

    def _onDisconnect(self, connection):
        print connection, "closed by remote rpc node"; sys.stdout.flush()
        assert connection is self.connection
        self.disconnected_event.set()

    def test1(self):
        # server hasn't started yet - expect connection refused
        print "Expecting 'Connection refused' error"
        self._startConnector()
        self.error_event.wait(1.0)
        assert self.error_event.isSet()
        assert isinstance(self.error_obj, socket.error)
        self.connector_done.wait(1.0)
        assert self.connector_done.isSet()

    def test2(self):
        # react to server closing the connection
        print "Expecting 'closed by remote rpc node' message"
        self._startListener()
        self._startConnector()
        self.connected_event.wait(1.0)
        assert self.connected_event.isSet(), "client never connected"
        proxy = self.connection.getBlockingProxy(1.0)
        assert proxy.add(3, 4) == 7
        self.listener.stop() # close all connections
        self.disconnected_event.wait(1.0)
        assert self.disconnected_event.isSet()
        try:
            proxy.add(3, 4)
        except Disconnected:
            pass # Ok
        else:
            assert False, "Expected 'Disconnected' exception"
        self.connector_done.wait(1.0)
        assert self.connector_done.isSet()
        assert not self.error_event.isSet()

    def test3(self):
        # react to client closing the connection
        self._startListener()
        self._startConnector()
        self.connected_event.wait(1.0)
        assert self.connected_event.isSet(), "client never connected"
        proxy = self.connection.getBlockingProxy(1.0)
        assert proxy.add(3, 4) == 7
        self.connection.close()
        # we do *not* expect _onDisconnect when we close it ourselves
        self.disconnected_event.wait(1.0)
        assert not self.disconnected_event.isSet()
        try:
            proxy.add(3, 4)
        except Disconnected:
            pass # Ok
        else:
            assert False, "Expected 'Disconnected' exception"
        self.connector_done.wait(1.0)
        assert self.connector_done.isSet()
        assert not self.error_event.isSet()
        # re-connect to server to make sure it didn't die
        self._startConnector()
        self.connected_event.wait(1.0)
        assert self.connected_event.isSet(), "client never re-connected"
        proxy = self.connection.getBlockingProxy(1.0)
        assert proxy.add(4, 5) == 9
        print "Expecting 'closed by remote rpc node' message"
        sys.stdout.flush()
        self.listener.stop()
        self.disconnected_event.wait(1.0)
        assert self.disconnected_event.isSet()

    def _createConnector(self):
        # Create a Connector with this suite's callbacks.  First stop any
        # previous connection's node and reset all per-test state.
        if self.connection:
            self.connection.node.stop()
            self.connection = None
        self.connected_event.clear()
        self.error_event.clear()
        self.disconnected_event.clear()
        self.connector_done.clear()
        self.error_obj = None
        connector = Connector(('127.0.0.1', self.port), self._clientFactory)
        connector.onErrorCall(self._onError)
        connector.onDisconnectCall(self._onDisconnect)
        return connector

    def _startConnector(self):
        # Run a freshly created connector in a daemon thread.
        self.connector = self._createConnector()
        t = threading.Thread(target=self._runConnector)
        t.setDaemon(True)
        t.start()

    def _runConnector(self):
        # Thread body: run the connector, then signal that it returned.
        self.connector.run()
        self.connector_done.set()

    def _startListener(self):
        # Start a Listener serving self.Server in a daemon thread.
        self.listener = Listener(('127.0.0.1', self.port), self.Server)
        t = threading.Thread(target=self.listener.run)
        t.setDaemon(True)
        t.start()

    def run(self):
        # Run every method whose name starts with 'test', in dir() order.
        for name in dir(self):
            if name.startswith('test'):
                print self.__class__.__name__, name; sys.stdout.flush()
                getattr(self, name)()

class _TestSuite3:
    # Focus on testing "InProc" (within the same process) messaging
    # The idea is to have nearly the same interface as if communicating
    # with a remote node.

    def test1(self):
        class InProcServerHandler:
            def __init__(self, connection):
                self.connection = connection
            def add(self, a, b):
                return a + b
            def serverHello(self):
                print "InProc server says hello"
                proxy = self.connection.getProxy(None)
                proxy.clientHello()
            def stop(self):
                self.connection.node.stop()
        server = InProcNode(InProcServerHandler)
        t = threading.Thread(target=self._test1Client, args=(server,))
        t.setDaemon(False) # Don't blame me, I didn't name this method
        t.start()
        server.run()
        print "InProc server stopped properly"

    def _test1Client(self, server):
        class InProcClientHandler:
            def __init__(self, connection):
                self.connection = connection
            def clientHello(self):
                print "InProc client says hello"
        connection = connect(server, InProcClientHandler)
        proxy = connection.getBlockingProxy()
        print "3 + 2 =", proxy.add(3, 2)
        connection.callMethod('serverHello', (), None)
        proxy = connection.getProxy(None)
        proxy.stop()
        connection.close()
        connection.node.stop()
        print "InProc client done"
                
    def run(self):
        for name in dir(self):
            if name.startswith('test'):
                print self.__class__.__name__, name; sys.stdout.flush()
                getattr(self, name)()


if __name__ == "__main__":
    # Executing this file directly runs the self-tests (see _Tester.main).
    _Tester().main()

# end-of-file
