I have a case where there are several models, some of which make S3 read/write calls. I want to execute each model's predict() function in parallel, but some of those predict() methods contain async calls, including the calls to S3. My code looks something like this:
class ModelA(BaseModel):
    def predict(self, df):
        # some prediction logic

class ModelB(BaseModel):
    async def predict(self, df):
        preds = []
        async with aiohttp.ClientSession() as session:
            for company, comp_df in df.groupby('company_id'):
                path = f'{some_path}/{company}.pqt'
                exists = await key_exists(WRITE_BUCKET, path)
                if exists:
                    # pre-generated OCR output that the model relies on
                    preds_df = await self.read_from_s3(path, session)
                else:
                    preds_df = run_ocr(comp_df)
                # some prediction logic
async def key_exists(bucket, key):
    async with aioboto3.Session().client("s3") as s3_client:
        try:
            LOGGER.info("entering get object in key_exists")
            await s3_client.head_object(Bucket=bucket, Key=key)
            LOGGER.info("exiting get object in key_exists")
        except Exception:
            return False
        return True
async def run_model(model):
    """
    Wrapper function to run the model.
    """
    # runs underlying predict()
    await model.run_predictions()

async def task_gatherer(subtasks: list):
    """Wrapper for gather function."""
    return await asyncio.gather(*subtasks)

def get_model_tasks(model):
    subtasks = [run_model(model)]
    result = asyncio.run(task_gatherer(subtasks))
    return result
async def run_async_models(async_models):
    with ProcessPoolExecutor(mp_context=mp.get_context('fork')) as pool:
        event_loop = asyncio.get_running_loop()
        master_tasks = [
            event_loop.run_in_executor(
                pool,
                get_model_tasks,
                model,
            )
            for model in async_models
        ]
        await asyncio.gather(*master_tasks)
Now, if I run the models sequentially as a standalone script, everything runs fine:
def main():
    individual_models = [ModelA(), ModelB()]
    for model in individual_models:
        asyncio.run(model.populate_table())
Output:
2024-02-01 05:24:25,571 - INFO - async_utils - entering get object in key_exists
2024-02-01 05:24:25,638 - INFO - async_utils - exiting get object in key_exists
But if I run them through the ProcessPoolExecutor, it hangs indefinitely, specifically in the key_exists() function at the call to s3_client.head_object():
async_models = [
    # add new models here:
    ModelA(),
    ModelB(),
]

# Run these models in parallel
asyncio.run(run_async_models(async_models))
I've tried sharing the session between the key_exists() and read_from_s3() functions, but it doesn't make a difference. I've also tried spawning the processes instead of forking them, but that doesn't matter either.
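For context, the session-sharing attempt looked roughly like the sketch below; the extra s3_client parameter is illustrative, not the exact code I ran:

async def predict(self, df):
    # one S3 client for the whole loop, passed to both helpers
    async with aioboto3.Session().client("s3") as s3_client:
        for company, comp_df in df.groupby('company_id'):
            path = f'{some_path}/{company}.pqt'
            if await key_exists(WRITE_BUCKET, path, s3_client):
                preds_df = await self.read_from_s3(path, s3_client)
            else:
                preds_df = run_ocr(comp_df)

async def key_exists(bucket, key, s3_client):
    try:
        await s3_client.head_object(Bucket=bucket, Key=key)
    except Exception:
        return False
    return True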
Turns out I needed to use a ThreadPoolExecutor, not a ProcessPoolExecutor. After making that switch, the code runs to completion.
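For reference, the working version only changes the executor; everything else stays the same. A minimal sketch, assuming the same get_model_tasks() helper as above:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def run_async_models(async_models):
    # Each worker thread calls get_model_tasks(), which starts its own
    # event loop via asyncio.run(). Threads share the parent's memory,
    # so the aioboto3/aiohttp clients behave just as they do in the
    # standalone script.
    with ThreadPoolExecutor() as pool:
        event_loop = asyncio.get_running_loop()
        master_tasks = [
            event_loop.run_in_executor(pool, get_model_tasks, model)
            for model in async_models
        ]
        await asyncio.gather(*master_tasks)

My best guess at the original hang is that forking a process which has already initialized network/SSL and event-loop state leaves the child with stale locks, so the head_object() call never returns; with threads there is nothing to re-initialize.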