Interactions between asyncio, pytest, and SQLAlchemy


I have a Python 3.8 program which runs a coroutine that loops infinitely, scheduling other coroutines in the background via create_task but not awaiting them. These coroutines use SQLAlchemy's AsyncEngine to write some data to a database. If an exception is raised, the handler cancels the looping coroutine and then awaits the remaining background tasks via gather. This all seems to work fine.

Stripped-down example:

import asyncio
from asyncio import Task

from sqlalchemy.ext.asyncio import AsyncEngine

outstanding_tasks: set = set()

async def write_loop():
    async_engine: AsyncEngine = get_async_engine()
    while True:
        data = await get_some_data()
        bg_task: Task = asyncio.create_task(write_to_db(async_engine, data))
        outstanding_tasks.add(bg_task)
        bg_task.add_done_callback(outstanding_tasks.discard)

async def main():
    try:
        await write_loop()
    except Exception:
        await asyncio.gather(*outstanding_tasks)

I'm using pytest and pytest-asyncio to implement some test cases exercising the above. The exception is forced via some patching, so the line with gather is the last thing that should happen. Each test case is decorated with @pytest.mark.asyncio, so each should get its own event loop. They all pass when run individually. When I run multiple at the same time, the SQLAlchemy writes fail with cannot perform operation: another operation is in progress. Googling told me that I needed to turn off connection pooling in tests, which seems to imply that the different tests somehow share an AsyncEngine (or at least the underlying connection pool?) despite each creating its own.

Having done that, I no longer get the above error. However, now if I run too many tests at once, at least one times out (no matter how high I set the timeout). I can see that the point where it hangs is await asyncio.gather(*outstanding_tasks). If I print the remaining tasks before calling gather, I see that they are in the pending state and have wait_for=<_GatheringFuture pending cb=[<TaskWakeupMethWrapper object]> (in addition to the callback that I attached).

I have the following questions:

  • Why are the tests sharing anything in the way of the AsyncEngine? Does this have to do with the fact (I think) that they run in the same process?
  • What does it mean that, when enough tests are run, the last few tasks never seem to get scheduled? My only guess was that they needed extra time, but I think I've disproven that with high timeouts.

In case it matters, the pytest functions are also parameterized such that each will be run multiple times with different arguments.

Edit: I should also mention that this behavior is not deterministic. Occasionally all tests pass.

1 Answer

Answered by Benyamin Jafari

One possible solution, which I personally prefer, is giving each test a clean database via a fixture. This requires refactoring the write_loop() coroutine to take the database session as a dependency-injected argument, instead of creating the AsyncEngine inside the method, passing it down to write_to_db, and reaching the async session from there.

import asyncio
from asyncio import Task

from sqlalchemy.ext.asyncio import AsyncSession

async def write_loop(session: AsyncSession):
    while True:
        data = await get_some_data()
        bg_task: Task = asyncio.create_task(write_to_db(session, data))
        ...

Now you can pass in a test database session, for example one backed by an in-memory SQLite database:

# client.py
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine

from module import models

DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine: AsyncEngine = create_async_engine(DATABASE_URL)
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

async def create_tables():
    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.create_all)

async def drop_tables():
    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.drop_all)

# conftest.py
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from . import client

@pytest_asyncio.fixture()
async def db() -> AsyncSession:
    async with client.async_session() as session:
        await client.create_tables()
        yield session
        await client.drop_tables()

# test_something.py
import asyncio

import pytest
from interruptingcow import timeout
from sqlalchemy.ext.asyncio import AsyncSession

from module import functions

@pytest.mark.asyncio
async def test_get_something(db: AsyncSession):
    # The bare try/except-with-asserts version of this test always passed,
    # because the bare except also swallowed the AssertionError;
    # pytest.raises expresses the intent correctly.
    with pytest.raises(asyncio.CancelledError):
        with timeout(10, exception=asyncio.CancelledError):
            await functions.write_loop(session=db)

Another refactoring that might be beneficial is splitting write_loop() into two new functions, run_once() and run_forever(), where run_forever() simply calls run_once() in an infinite loop. That way you can test run_once() directly, without needing an extra package such as interruptingcow to interrupt the loop via a timeout.
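A sketch of that split (the get_some_data() and write_to_db() stubs here are hypothetical stand-ins for the question's real helpers, just so the snippet runs on its own):

```python
import asyncio

# Hypothetical stand-ins for the question's helpers:
async def get_some_data():
    return {"value": 1}

async def write_to_db(session, data):
    await asyncio.sleep(0)  # pretend to write

outstanding_tasks: set = set()

async def run_once(session) -> None:
    """One iteration of the old write_loop body; easy to test directly."""
    data = await get_some_data()
    task = asyncio.create_task(write_to_db(session, data))
    outstanding_tasks.add(task)
    task.add_done_callback(outstanding_tasks.discard)

async def run_forever(session) -> None:
    """Kept trivial on purpose: all the logic worth testing is in run_once."""
    while True:
        await run_once(session)
```

A test then awaits run_once() once and gathers the background task it spawned, with no timeout machinery needed.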