Correct destruction process for async code running in a thread


Below is (working) code for a generic websocket streamer.

It creates a daemon thread that runs asyncio.run(...).

The asyncio code spawns 2 tasks, which never complete.

How do I correctly destroy this object?

One of the tasks runs a keepalive 'ping' loop, which I can easily exit using a flag. But the other blocks while waiting for a message from the websocket.

import json
import aiohttp
import asyncio
import gzip

from threading import Thread

class WebSocket:
    KEEPALIVE_INTERVAL_S = 10

    def __init__(self, url, on_connect, on_msg):
        self.url = url
        self.on_connect = on_connect
        self.on_msg = on_msg

        self.streams = {}
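        # Thread.start() returns None, so keep the Thread object and start it separately
        self.worker_thread = Thread(name='WebSocket', target=self.thread_func, daemon=True)
        self.worker_thread.start()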

    def thread_func(self):
        asyncio.run(self.aio_run())

    async def aio_run(self):
        async with aiohttp.ClientSession() as session:

            self.ws = await session.ws_connect(self.url)

            await self.on_connect(self)

            async def ping():
                while True:
                    print('KEEPALIVE')
                    await self.ws.ping()
                    await asyncio.sleep(WebSocket.KEEPALIVE_INTERVAL_S)

            async def main_loop():
                async for msg in self.ws:
                    def extract_data(msg):
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            as_bytes = gzip.decompress(msg.data)
                            as_string = as_bytes.decode('utf8')
                            as_json = json.loads(as_string)
                            return as_json

                        elif msg.type == aiohttp.WSMsgType.TEXT:
                            return json.loads(msg.data)

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            print('⛔️ aiohttp.WSMsgType.ERROR')

                        return msg.data

                    data = extract_data(msg)

                    self.on_msg(data)

            # May want this approach if we want to handle graceful shutdown
            # self.task_ping = asyncio.create_task(ping())
            # self.task_main_loop = asyncio.create_task(main_loop())

            await asyncio.gather(
                ping(),
                main_loop()
            )

    async def send_json(self, J):
        await self.ws.send_json(J)
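The asymmetry described in the question can be reproduced without a network connection. In this sketch (all names are hypothetical), an asyncio.Event plays the role of the flag and an asyncio.Queue stands in for the websocket: the keepalive loop notices the flag because it wakes up between pings, while the reader stays suspended inside its await and only leaves it when the task is cancelled.

```python
import asyncio


async def keepalive(stop: asyncio.Event, interval: float) -> int:
    """Stand-in for ping(): exits as soon as `stop` is set."""
    sent = 0
    while not stop.is_set():
        sent += 1  # a real implementation would `await ws.ping()` here
        try:
            # Sleep for `interval`, but wake up early if the flag is set.
            await asyncio.wait_for(stop.wait(), timeout=interval)
        except asyncio.TimeoutError:
            pass
    return sent


async def reader(inbox: asyncio.Queue, stop: asyncio.Event) -> None:
    """Stand-in for `async for msg in ws`: suspended in get(), never re-checks the flag."""
    while not stop.is_set():
        await inbox.get()


async def demo() -> tuple:
    stop = asyncio.Event()
    inbox: asyncio.Queue = asyncio.Queue()  # nothing is ever put here
    ping_task = asyncio.create_task(keepalive(stop, interval=0.01))
    reader_task = asyncio.create_task(reader(inbox, stop))

    await asyncio.sleep(0.05)
    stop.set()
    sent = await ping_task                  # the keepalive loop honours the flag

    await asyncio.sleep(0.05)
    still_blocked = not reader_task.done()  # the reader is still stuck in get()

    reader_task.cancel()                    # cancellation is what frees it
    try:
        await reader_task
    except asyncio.CancelledError:
        pass
    return sent, still_blocked


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

With a real aiohttp socket, closing the socket is the other way to release the reader, since the iteration then ends on its own.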


There are 2 answers

Łukasz Kwieciński

I'd suggest using asyncio.run_coroutine_threadsafe instead of asyncio.run. It submits the coroutine to an event loop that is already running in another thread and returns a concurrent.futures.Future object, which you can cancel:

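# in __init__: the worker thread only drives a dedicated event loop
self.loop = asyncio.new_event_loop()
self.worker_thread = Thread(name='WebSocket', target=self.loop.run_forever, daemon=True)
self.worker_thread.start()

# schedule aio_run() on that loop from the calling thread
self.future = asyncio.run_coroutine_threadsafe(self.aio_run(), self.loop)

# somewhere else
self.future.cancel()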
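run_coroutine_threadsafe needs an event loop that is already running in some thread, and after cancelling you still have to stop that loop and join the thread. A minimal sketch under those assumptions: the worker thread runs loop.run_forever(), and aio_run_stub is a hypothetical placeholder for aio_run() that loops until it is cancelled.

```python
import asyncio
import threading


async def aio_run_stub(started: threading.Event, cleaned_up: threading.Event) -> None:
    """Hypothetical stand-in for WebSocket.aio_run(): loops until cancelled."""
    started.set()
    try:
        while True:
            await asyncio.sleep(0.01)  # stands in for ping()/main_loop()
    finally:
        # Runs when the task is cancelled: close the socket/session here.
        cleaned_up.set()


def run_and_cancel() -> bool:
    loop = asyncio.new_event_loop()
    # The worker thread only drives the loop; coroutines are submitted from outside.
    thread = threading.Thread(target=loop.run_forever, name="WebSocket")
    thread.start()

    started, cleaned_up = threading.Event(), threading.Event()
    future = asyncio.run_coroutine_threadsafe(aio_run_stub(started, cleaned_up), loop)
    started.wait(timeout=1)  # make sure the coroutine has actually begun

    # Cancelling the concurrent.futures.Future cancels the task inside the loop.
    future.cancel()
    cleaned_up.wait(timeout=1)

    # Stop the loop from its own thread, then wait for the worker to exit.
    loop.call_soon_threadsafe(loop.stop)
    thread.join()
    loop.close()
    return cleaned_up.is_set() and future.cancelled()


if __name__ == "__main__":
    print(run_and_cancel())
```

The threading.Event objects are only there so the calling thread can tell when the coroutine has started and when its cleanup has run; a real class might track this differently.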

Another approach would be to make ping and main_loop tasks and cancel them when necessary:

# inside `aio_run`
self.task_ping = asyncio.create_task(ping())
self.main_loop_task = asyncio.create_task(main_loop())

await asyncio.gather(
    self.task_ping,
    self.main_loop_task,
    return_exceptions=True
)


# somewhere else
self.task_ping.cancel()
self.main_loop_task.cancel()
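The "somewhere else" code usually runs in a different thread from the event loop, and Task.cancel() is not thread-safe, so the cancellation has to be handed to the loop with loop.call_soon_threadsafe. A minimal sketch, assuming a hypothetical Streamer class with no network I/O (an asyncio.sleep loop and a never-set asyncio.Event stand in for the ping and the websocket read):

```python
import asyncio
import threading


class Streamer:
    """Hypothetical minimal stand-in for the WebSocket class (no network I/O)."""

    def __init__(self) -> None:
        self.results = None
        self.ready = threading.Event()
        self.thread = threading.Thread(
            target=lambda: asyncio.run(self.aio_run()), name="WebSocket"
        )
        self.thread.start()

    async def aio_run(self) -> None:
        async def ping():
            while True:
                await asyncio.sleep(0.01)  # stands in for ws.ping() + sleep

        async def main_loop():
            await asyncio.Event().wait()  # stands in for `async for msg in self.ws`

        self.loop = asyncio.get_running_loop()
        self.task_ping = asyncio.create_task(ping())
        self.main_loop_task = asyncio.create_task(main_loop())
        self.ready.set()
        # return_exceptions=True: the CancelledErrors are returned, not raised.
        self.results = await asyncio.gather(
            self.task_ping, self.main_loop_task, return_exceptions=True
        )

    def stop(self) -> None:
        self.ready.wait()
        # Task.cancel() is not thread-safe; schedule it on the loop's own thread.
        self.loop.call_soon_threadsafe(self.task_ping.cancel)
        self.loop.call_soon_threadsafe(self.main_loop_task.cancel)
        self.thread.join()


if __name__ == "__main__":
    s = Streamer()
    s.stop()
    print(s.results)
```

Once both tasks are cancelled, gather returns, asyncio.run finishes, and join() returns.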

This doesn't change the fact that aio_run should still be scheduled with asyncio.run_coroutine_threadsafe: asyncio.run is meant to be the main entry point of an asyncio program and should ideally be called only once.

alex_noname

I would like to suggest one more variation. When finishing coroutines (tasks), I prefer to minimize (though not eliminate) the use of cancel(), since it can sometimes make business logic harder to debug (keep in mind that since Python 3.8 asyncio.CancelledError inherits from BaseException, not from Exception).
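A small illustration of the CancelledError point: a broad `except Exception:` handler in business logic does not see a cancellation, so the error passes straight through it. The coroutine names here are hypothetical.

```python
import asyncio


async def business_logic() -> str:
    try:
        await asyncio.sleep(10)
        return "finished"
    except Exception:
        # CancelledError derives from BaseException (Python 3.8+),
        # so cancellation skips this handler entirely.
        return "caught by except Exception"


async def main() -> str:
    task = asyncio.create_task(business_logic())
    await asyncio.sleep(0)  # let the task start and block in sleep()
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        return "cancellation propagated"
    return "swallowed"


if __name__ == "__main__":
    print(issubclass(asyncio.CancelledError, Exception))
    print(asyncio.run(main()))
```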

In your case, the code might look like this (only the changes are shown):

class WebSocket:
    KEEPALIVE_INTERVAL_S = 10

    def __init__(self, url, on_connect, on_msg):
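        # ...
        # keep a reference to the Thread so stop() can join() it
        self.worker_thread = Thread(name='WebSocket', target=self.thread_func)
        self.worker_thread.start()

    async def aio_run(self):
        # remember the worker thread's loop so stop() can schedule _stop() on it
        self._loop = asyncio.get_running_loop()
        # ...
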
        self._ping_task = asyncio.create_task(ping())
        self._main_task = asyncio.create_task(main_loop())

        await asyncio.gather(
            self._ping_task,
            self._main_task,
            return_exceptions=True
        )
        # ...

    async def stop_ping(self):
        self._ping_task.cancel()
        try:
            await self._ping_task
        except asyncio.CancelledError:
            pass

    async def _stop(self):
        # wait for the ping task to finish before closing the socket
        await self.stop_ping()
        # closing the socket makes `async for msg in self.ws` exit cleanly
        await self.ws.close()

    def stop(self):
        # block until the ping task is stopped and the socket is closed
        asyncio.run_coroutine_threadsafe(
            self._stop(), self._loop
        ).result()
        self.worker_thread.join()  # wait for the worker thread to finish
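To try this pattern end to end without a server, here is a self-contained sketch in which a hypothetical FakeWS class replaces aiohttp's websocket. It implements only ping(), close() and async iteration, and close() makes the iteration stop, which is the behaviour the answer relies on from self.ws.close(). FakeWS.push() and Streamer.feed() are hypothetical test helpers.

```python
import asyncio
import threading


class FakeWS:
    """Hypothetical stand-in for aiohttp's websocket response (no network)."""

    _CLOSED = object()

    def __init__(self) -> None:
        self._inbox: asyncio.Queue = asyncio.Queue()

    async def ping(self) -> None:
        pass

    async def push(self, msg) -> None:
        await self._inbox.put(msg)  # simulates an incoming message

    async def close(self) -> None:
        await self._inbox.put(self._CLOSED)  # wakes up the pending `async for`

    def __aiter__(self):
        return self

    async def __anext__(self):
        msg = await self._inbox.get()
        if msg is self._CLOSED:
            raise StopAsyncIteration  # ends `async for msg in ws` cleanly
        return msg


class Streamer:
    KEEPALIVE_INTERVAL_S = 0.01

    def __init__(self, on_msg) -> None:
        self.on_msg = on_msg
        self._ready = threading.Event()
        self.worker_thread = threading.Thread(name="WebSocket", target=self.thread_func)
        self.worker_thread.start()
        self._ready.wait()  # tasks and self._loop exist once this returns

    def thread_func(self) -> None:
        asyncio.run(self.aio_run())

    async def aio_run(self) -> None:
        self._loop = asyncio.get_running_loop()
        self.ws = FakeWS()

        async def ping():
            while True:
                await self.ws.ping()
                await asyncio.sleep(self.KEEPALIVE_INTERVAL_S)

        async def main_loop():
            async for msg in self.ws:
                self.on_msg(msg)

        self._ping_task = asyncio.create_task(ping())
        self._main_task = asyncio.create_task(main_loop())
        self._ready.set()
        await asyncio.gather(self._ping_task, self._main_task, return_exceptions=True)

    async def _stop(self) -> None:
        self._ping_task.cancel()
        try:
            await self._ping_task
        except asyncio.CancelledError:
            pass
        await self.ws.close()

    def stop(self) -> None:
        asyncio.run_coroutine_threadsafe(self._stop(), self._loop).result()
        self.worker_thread.join()

    def feed(self, msg) -> None:
        asyncio.run_coroutine_threadsafe(self.ws.push(msg), self._loop).result()


if __name__ == "__main__":
    received = []
    s = Streamer(received.append)
    s.feed({"a": 1})
    s.stop()
    print(received)
```

Only the ping task is cancelled here; the reader finishes normally because the close marker is queued behind any messages already received, so nothing is dropped before the thread exits.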