Asyncio persistent client protocol class using a queue

4.2k views Asked by At

I'm trying to get my head around the Python 3 asyncio module, in particular using the transport/protocol API. I want to create a publish/subscribe pattern, and use the asyncio.Protocol class to create my client and server.

At the moment I've got the server up and running, and listening for incoming client connections. The client is able to connect to the server, send a message and receive the reply.

I would like to be able to keep the TCP connection alive and maintain a queue that allows me to add messages. I've tried to find a way to do this using the low-level API (transports/protocols), but the limited asyncio docs and examples online all seem to focus on the high-level API (streams, etc.). Could someone point me in the right direction on how to implement this?

Here's the server code:

#!/usr/bin/env python3

import asyncio
import json


class SubscriberServerProtocol(asyncio.Protocol):
    """ A Server Protocol listening for subscriber messages.

    Instance attributes set in connection_made:
        peername:   the peer address reported by the transport
        transport:  the transport used to write replies
    """

    def connection_made(self, transport):
        """ Called when connection is initiated """

        self.peername = transport.get_extra_info('peername')
        print('connection from {}'.format(self.peername))
        self.transport = transport

    def data_received(self, data):
        """ The protocol expects a json message containing
        the following fields:

            type:       subscribe/unsubscribe
            channel:    the name of the channel

        For a valid subscribe/unsubscribe request the protocol logs the
        action and replies with the following json message:

            type:           subscribe/unsubscribe
            channel:        The channel the subscriber registered to
            channel_count:  the amount of channels registered

        A JSON object with any other ``type`` (or none) is answered with
        ``{"type":"unknown_type"}``. Input that is not valid UTF-8, not
        valid JSON, not a JSON object, or a subscribe/unsubscribe request
        without a ``channel`` is answered with ``{"type":"invalid_message"}``
        instead of letting the exception escape this callback.

        NOTE(review): each ``data`` chunk is assumed to contain exactly one
        complete JSON document. TCP does not preserve message boundaries,
        so split or coalesced messages are reported as invalid; delimiter
        based framing (e.g. one message per line) would be needed to
        handle that. Registration with a pubsub hub is not implemented
        here yet.
        """
        recv_message = self._parse_request(data)
        if recv_message is None or (
                recv_message.get('type') in ('subscribe', 'unsubscribe')
                and 'channel' not in recv_message):
            print('Invalid message from {}: {!r}'.format(self.peername, data))
            self._send_reply({'type': 'invalid_message'})
            return

        msg_type = recv_message.get('type')

        # Check the message type and subscribe/unsubscribe
        # to the channel, then inform the client.
        # TODO(review): channel_count values 10/9 look like placeholders
        # until real per-client channel bookkeeping exists.
        if msg_type == 'subscribe':
            print('Client {} subscribed to {}'.format(self.peername,
                                                      recv_message['channel']))
            reply = {'type': 'subscribe',
                     'channel': recv_message['channel'],
                     'channel_count': 10}
        elif msg_type == 'unsubscribe':
            print('Client {} unsubscribed from {}'
                  .format(self.peername, recv_message['channel']))
            reply = {'type': 'unsubscribe',
                     'channel': recv_message['channel'],
                     'channel_count': 9}
        else:
            print('Invalid message type {}'.format(msg_type))
            reply = {'type': 'unknown_type'}

        self._send_reply(reply)

    @staticmethod
    def _parse_request(data):
        """ Decode ``data`` as a UTF-8 JSON object.

        Returns the decoded dict, or None when the bytes are not valid
        UTF-8, not valid JSON, or the JSON value is not an object.
        """
        try:
            message = json.loads(data.decode())
        except ValueError:
            # UnicodeDecodeError and json's decode error both subclass ValueError.
            return None
        return message if isinstance(message, dict) else None

    def _send_reply(self, payload):
        """ Serialize ``payload`` as compact JSON, log it and write it to the client. """
        send_message = json.dumps(payload, separators=(',', ':'))
        print('Sending {!r}'.format(send_message))
        self.transport.write(send_message.encode())

    def eof_received(self):
        """ an EOF has been received from the client.

        This indicates the client has gracefully exited
        the connection. Inform the pubsub hub that the
        subscriber is gone
        """
        print('Client {} closed connection'.format(self.peername))
        self.transport.close()

    def connection_lost(self, exc):
        """ A transport error or EOF is seen which
        means the client is disconnected.

        Inform the pubsub hub that the subscriber has
        Disappeared
        """
        if exc:
            print('{} {}'.format(exc, self.peername))


# Create the loop explicitly: relying on asyncio.get_event_loop() to
# create one implicitly is deprecated in recent Python versions.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Each client will create a new protocol instance
coro = loop.create_server(SubscriberServerProtocol, '127.0.0.1', 10666)
server = loop.run_until_complete(coro)

# Serve requests until Ctrl+C
print('Serving on {}'.format(server.sockets[0].getsockname()))
try:
    loop.run_forever()
except KeyboardInterrupt:
    pass
finally:
    # Stop accepting new connections.
    server.close()
    # On newer Pythons wait_closed() also waits for active client
    # connections to finish; close them first where the API exists
    # (Server.close_clients, Python 3.13+) so shutdown does not block.
    if hasattr(server, 'close_clients'):
        server.close_clients()
    loop.run_until_complete(server.wait_closed())
    loop.close()

And here's the client code:

#!/usr/bin/env python3

import asyncio
import json


class SubscriberClientProtocol(asyncio.Protocol):
    """ Client protocol that writes one pre-built message as soon as the
    connection is up, prints every reply, and stops the given event loop
    when the connection goes away.
    """

    def __init__(self, message, loop):
        # JSON text (str) sent once, right after connecting.
        self.message = message
        # Event loop to stop from connection_lost.
        self.loop = loop

    def connection_made(self, transport):
        """ Write the stored message to the freshly connected server.

        A message has to have the following items:
            type:       subscribe/unsubscribe
            channel:    the name of the channel
        """
        payload = self.message.encode()
        transport.write(payload)
        print('Message sent: %r' % (self.message,))

    def data_received(self, data):
        """ Print a reply received from the server.

        The return message consist of three fields:
            type:           subscribe/unsubscribe
            channel:        the name of the channel
            channel_count:  the amount of channels subscribed to
        """
        text = data.decode()
        print('Message received: %r' % (text,))

    def connection_lost(self, exc):
        """ Report the disconnect (regardless of ``exc``) and stop the loop. """
        for line in ('The server closed the connection', 'Stop the event loop'):
            print(line)
        self.loop.stop()

if __name__ == '__main__':
    message = json.dumps({'type': 'subscribe', 'channel': 'sensor'},
                         separators=(',', ':'))

    # Create the loop explicitly: relying on asyncio.get_event_loop() to
    # create one implicitly is deprecated in recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    coro = loop.create_connection(lambda: SubscriberClientProtocol(message,
                                                                   loop),
                                  '127.0.0.1', 10666)
    loop.run_until_complete(coro)
    try:
        # Runs until Ctrl+C or until the protocol stops the loop.
        loop.run_forever()
    except KeyboardInterrupt:
        print('Closing connection')
    finally:
        loop.close()
1

There is 1 answer

2
dano On BEST ANSWER

Your server is fine as-is for what you're trying to do; your code as written actually keeps the TCP connection alive. You just don't have the plumbing in place to continuously feed it new messages. To do that, you need to tweak the client code so that you can feed new messages into it whenever you want, rather than only doing it when the connection_made callback fires.

This is easy enough; we'll add an internal asyncio.Queue to the ClientProtocol which can receive messages, and then run a coroutine in an infinite loop that consumes the messages from that Queue, and sends them on to the server. The final piece is to actually store the ClientProtocol instance you get back from the create_connection call, and then pass it to a coroutine that actually sends messages.

import asyncio
import json

class SubscriberClientProtocol(asyncio.Protocol):
    """ Client protocol that keeps its TCP connection open and writes
    messages to the server as they are fed in through send_message().

    A background task drains an internal asyncio.Queue once the
    connection has been made.
    """

    def __init__(self, loop):
        self.transport = None            # set in connection_made
        self.loop = loop                 # stopped from connection_lost
        self.queue = asyncio.Queue()     # outgoing messages (str)
        self._ready = asyncio.Event()    # set once the transport exists
        # loop.create_task() instead of asyncio.async(): the latter is a
        # SyntaxError from Python 3.7 on (``async`` became a keyword).
        self._sender = loop.create_task(self._send_messages())

    async def _send_messages(self):
        """ Send messages to the server as they become available.

        Waits for connection_made, then forever takes the next str from
        the queue, UTF-8 encodes it and writes it to the transport.
        """
        await self._ready.wait()
        print("Ready!")
        while True:
            data = await self.queue.get()
            self.transport.write(data.encode('utf-8'))
            # Log the message actually sent (not a module-level global).
            print('Message sent: {!r}'.format(data))

    def connection_made(self, transport):
        """ Store the transport and release the sender task.

        A message has to have the following items:
            type:       subscribe/unsubscribe
            channel:    the name of the channel
        """
        self.transport = transport
        print("Connection made.")
        self._ready.set()

    async def send_message(self, data):
        """ Feed a message (str) to the sender coroutine. """
        await self.queue.put(data)

    def data_received(self, data):
        """ After sending a message we expect a reply
        back from the server

        The return message consist of three fields:
            type:           subscribe/unsubscribe
            channel:        the name of the channel
            channel_count:  the amount of channels subscribed to
        """
        print('Message received: {!r}'.format(data.decode()))

    def connection_lost(self, exc):
        """ Stop the sender task and the event loop once the connection is gone. """
        # Nothing can be written any more; don't leave the task pending.
        self._sender.cancel()
        print('The server closed the connection')
        print('Stop the event loop')
        self.loop.stop()

async def feed_messages(protocol, interval=1):
    """ An example function that sends the same message repeatedly.

    Forever queues a compact subscribe request for channel ``sensor`` via
    ``protocol.send_message`` and then sleeps ``interval`` seconds
    (default 1, the original fixed delay).
    """
    message = json.dumps({'type': 'subscribe', 'channel': 'sensor'},
                         separators=(',', ':'))
    while True:
        await protocol.send_message(message)
        await asyncio.sleep(interval)

if __name__ == '__main__':
    message = json.dumps({'type': 'subscribe', 'channel': 'sensor'},
                         separators=(',', ':'))

    # Create the loop explicitly: relying on asyncio.get_event_loop() to
    # create one implicitly is deprecated in recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    coro = loop.create_connection(lambda: SubscriberClientProtocol(loop),
                                  '127.0.0.1', 10666)
    _, proto = loop.run_until_complete(coro)
    # loop.create_task() instead of asyncio.async(), which is a
    # SyntaxError from Python 3.7 on.
    loop.create_task(feed_messages(proto))
    try:
        # Runs until Ctrl+C or until the protocol stops the loop.
        loop.run_forever()
    except KeyboardInterrupt:
        print('Closing connection')
    finally:
        loop.close()