How to cancel tasks in anyio.TaskGroup context?

190 views Asked by At

I wrote a script to find the fastest host in a list of CDN hosts:

#!/usr/bin/env python3.11
import time
from contextlib import contextmanager
from enum import StrEnum

import anyio
import httpx


@contextmanager
def timeit(msg: str):
    """Print how long the ``with`` body took, prefixed by *msg*.

    Output has the form ``<msg> cost = <seconds>``.  The interval is measured
    with ``time.perf_counter`` (a monotonic clock meant for timing) instead of
    the wall clock, and it is reported even when the body raises.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Keep the local named ``cost``: the ``=`` f-string echoes the name.
        cost = time.perf_counter() - start
        print(msg, f"{cost = }")


class CdnHost(StrEnum):
    jsdelivr = "https://cdn.jsdelivr.net/npm/[email protected]/swagger-ui.css"
    unpkg = "https://unpkg.com/[email protected]/swagger-ui.css"
    cloudflare = (
        "https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/5.9.0/swagger-ui.css"
    )


# Defaults for find_fastest_host: overall time budget and polling interval,
# both in seconds.
TIMEOUT, LOOP_INTERVAL = 5, 0.1


async def fetch(client, url, results, index):
    """GET *url* with *client* and store the body in ``results[index]``.

    The slot is only filled for responses whose status code is below 300.
    Transport-level failures (connection/read/write errors and timeouts)
    are swallowed and leave the slot as ``None``, so an unreachable or slow
    host simply loses the race instead of aborting the whole task group.
    """
    try:
        r = await client.get(url)
    except httpx.TransportError:
        # Superset of the previously caught ConnectError/ReadError; it also
        # covers ConnectTimeout/ReadTimeout, which used to escape.
        return
    # Keep the names ``url`` and ``r``: the ``=`` f-string echoes them.
    print(f"{url = }\n{r.elapsed = }")
    if r.status_code < 300:
        results[index] = r.content


class StopNow(Exception):
    """Raised by the polling loop to leave the task group once a host answered."""


async def find_fastest_host(timeout=TIMEOUT, loop_interval=LOOP_INTERVAL) -> str:
    urls = list(CdnHost)
    results = [None] * len(urls)
    try:
        async with anyio.create_task_group() as tg:
            with anyio.move_on_after(timeout):
                async with httpx.AsyncClient() as client:
                    for i, url in enumerate(urls):
                        tg.start_soon(fetch, client, url, results, i)
                    for _ in range(int(timeout / loop_interval) + 1):
                        for res in results:
                            if res is not None:
                                raise StopNow
                        await anyio.sleep(0.1)
    except (
        StopNow,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.ConnectError,
        httpx.ConnectTimeout,
    ):
        ...
    for url, res in zip(urls, results):
        if res is not None:
            return url
    return urls[0]


async def main():
    """Find the fastest CDN host, timing the search, and print the outcome."""
    with timeit("Sniff hosts"):
        url = await find_fastest_host()
    # Printing the enum class itself only shows "<enum 'CdnHost'>"; list the
    # candidate host names instead.
    print("cdn host:", [host.name for host in CdnHost])
    print("result:", url)


if __name__ == "__main__":
    anyio.run(main)

There are three CDN hosts (https://cdn.jsdelivr.net, https://unpkg.com, https://cdnjs.cloudflare.com). I start three concurrent async tasks that fetch them with httpx. As soon as one of them gets a response with status_code < 300, I want to stop all the tasks and return that URL. But I don't know how to cancel the tasks without raising a custom exception (StopNow in the script).

2

There are 2 answers

2
blhsing On BEST ANSWER

You can call the `cancel()` method of the task group's `cancel_scope` attribute to cancel all of its tasks:

async with anyio.create_task_group() as tg:
    ...
    # Cancelling the task group's cancel scope cancels all tasks in the group.
    tg.cancel_scope.cancel()
0
Waket Zheng On

Thanks, blhsing — `tg.cancel_scope.cancel()` did work.

Here is the final code:

#!/usr/bin/env python3.11
import time
from contextlib import contextmanager
from enum import StrEnum

import anyio
import httpx


@contextmanager
def timeit(msg: str):
    """Print how long the ``with`` body took, prefixed by *msg*.

    Output has the form ``<msg> cost = <seconds>``.  The interval is measured
    with ``time.perf_counter`` (a monotonic clock meant for timing) instead of
    the wall clock, and it is reported even when the body raises.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Keep the local named ``cost``: the ``=`` f-string echoes the name.
        cost = time.perf_counter() - start
        print(msg, f"{cost = }")


class CdnHost(StrEnum):
    jsdelivr = "https://cdn.jsdelivr.net/npm/[email protected]/swagger-ui.css"
    unpkg = "https://unpkg.com/[email protected]/swagger-ui.css"
    cloudflare = (
        "https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/5.9.0/swagger-ui.css"
    )



async def fetch(client, url, results, index) -> None:
    """GET *url* with *client* and store the body in ``results[index]``.

    The slot is only filled for responses whose status code is below 300.
    Transport-level failures (connection/read/write errors and timeouts)
    are swallowed and leave the slot as ``None``, so an unreachable or slow
    host simply loses the race instead of crashing the task group -- the
    caller configures the client with a timeout, so timeouts do happen here.
    """
    try:
        r = await client.get(url)
    except httpx.TransportError:
        # Superset of the previously caught ConnectError/ReadError; it also
        # covers ConnectTimeout/ReadTimeout, which used to escape.
        return
    # Keep the names ``url`` and ``r``: the ``=`` f-string echoes them.
    print(f"{url = }\n{r.elapsed = }")
    if r.status_code < 300:
        results[index] = r.content


async def find_fastest_host(
    urls: list[str], total_seconds=5, loop_interval=0.1
) -> str:
    """Race *urls* concurrently and return one that answered successfully.

    Every URL is fetched in its own task while the host task checks
    ``results`` every ``loop_interval`` seconds.  As soon as any download
    succeeded -- or after roughly ``total_seconds`` of polling -- the task
    group is cancelled so no request outlives the race.  If several hosts
    answered by the time of a check, the earliest one in *urls* wins.  If
    none succeeded, ``urls[0]`` is returned as a fallback, so *urls* must be
    non-empty.
    """
    results = [None] * len(urls)
    # Open the client *outside* the task group: exits run in reverse order,
    # so the client is only closed after every fetch task has finished or
    # been cancelled, never while a request is still using it.
    async with httpx.AsyncClient(timeout=total_seconds) as client:
        async with anyio.create_task_group() as tg:
            for i, url in enumerate(urls):
                tg.start_soon(fetch, client, url, results, i)
            for _ in range(int(total_seconds / loop_interval) + 1):
                if any(r is not None for r in results):
                    break
                await anyio.sleep(loop_interval)
            # Either a host answered or the polling budget is spent: stop the
            # remaining fetches instead of waiting for their own timeouts.
            tg.cancel_scope.cancel()
    for url, res in zip(urls, results):
        if res is not None:
            return url
    return urls[0]


async def main():
    """Pick the fastest CDN host, report the search time and the winner."""
    with timeit("Sniff hosts"):
        fastest = await find_fastest_host(list(CdnHost))
    host_names = [member.name for member in CdnHost]
    print("cdn host list:", host_names)
    print("result:", fastest)


if __name__ == "__main__":
    anyio.run(main)