Mocks passed to LangChain are not preserved


I'm using LangChain in Python, with the following test dependencies:

pytest ~= 7.4.0
pytest-asyncio ~= 0.21.1
pytest-mock ~= 3.11.1

I wrote a custom chain implementation along these lines:

from langchain.chains.base import Chain
from langchain.schema import BaseRetriever, Document


class FooChain(Chain):
    retriever: BaseRetriever

    async def _aget_docs(self, query: str) -> list[Document]:
        return await self.retriever.aget_relevant_documents(query)

    # ...

I'm trying to test it like this:

from unittest.mock import ANY, AsyncMock

import pytest


@pytest.mark.asyncio
async def test_aget_docs(fake_llm: FakeLLM, combine_docs_chain: BaseCombineDocumentsChain):
    document = Document(page_content="Foo bar baz")

    retriever = AsyncMock(BaseRetriever)
    retriever.aget_relevant_documents.return_value = [document] * 4

    foo_chain = FooChain(retriever=retriever)

    await foo_chain._aget_docs(query="foo")

    retriever.aget_relevant_documents.assert_awaited_once_with("foo", ANY)
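For reference, the test above relies on how `AsyncMock` treats a spec: coroutine functions on the spec class become `AsyncMock` children (so awaiting them works), while synchronous methods become `MagicMock` children. A minimal stdlib-only sketch, where `FakeRetriever` is a hypothetical stand-in for `BaseRetriever`, not the real LangChain class:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


class FakeRetriever:
    """Hypothetical stand-in for BaseRetriever; only the method kinds matter."""

    async def aget_relevant_documents(self, query, **kwargs):
        raise NotImplementedError

    def get_relevant_documents(self, query, **kwargs):
        raise NotImplementedError


retriever = AsyncMock(FakeRetriever)  # first positional argument is the spec
retriever.aget_relevant_documents.return_value = ["doc"] * 4

# Coroutine functions on the spec become AsyncMock children (awaitable) ...
print(isinstance(retriever.aget_relevant_documents, AsyncMock))  # True
# ... while plain sync methods become MagicMock children.
print(isinstance(retriever.get_relevant_documents, MagicMock))  # True

docs = asyncio.run(retriever.aget_relevant_documents("foo"))
print(docs)  # ['doc', 'doc', 'doc', 'doc']
retriever.aget_relevant_documents.assert_awaited_once_with("foo")
```

Used directly like this, the mock behaves the way the test expects; the failure only appears once the chain ends up holding something other than the original mock.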

When I run the test, I get:

>       docs = await self.retriever.aget_relevant_documents(query, callbacks=run_manager.get_child())
E       TypeError: object MagicMock can't be used in 'await' expression

Stepping through with a debugger, I can see that self.retriever is <MagicMock name='mock._copy_and_set_values()' id='4960019664'> rather than the AsyncMock I passed in. It looks like something is replacing the value I pass to the constructor, but I can't figure out how or where that happens.

I know LangChain's Chain class ultimately extends Pydantic's BaseModel, but when I tried to reproduce the issue with a plain Pydantic model, I couldn't make the same thing happen.
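The mock name in the debugger output suggests the stored value is the return value of a synchronous `_copy_and_set_values()` call made on the AsyncMock. A minimal stdlib-only sketch of how that could happen, assuming field validation behaves like Pydantic v1's copy-on-validation for fields typed as a model class: if the value passes `isinstance` (which a mock created with that class as its spec does), a copy made through `_copy_and_set_values` is stored instead of the value itself. `FakeModel` and `validate_field` are hypothetical stand-ins, not LangChain or Pydantic APIs:

```python
import asyncio
from unittest.mock import AsyncMock


class FakeModel:
    """Hypothetical stand-in for a Pydantic-v1-style model such as BaseRetriever."""

    def _copy_and_set_values(self, values, fields_set, *, deep):
        raise NotImplementedError  # sync method -> MagicMock child on an AsyncMock

    async def aget_relevant_documents(self, query, **kwargs):
        raise NotImplementedError  # async method -> AsyncMock child


def validate_field(value):
    """Hypothetical validator mimicking (by assumption) copy-on-validation."""
    if isinstance(value, FakeModel):  # a spec'd mock passes this check
        # Arguments are placeholders; only the sync call itself matters here.
        return value._copy_and_set_values({}, set(), deep=False)
    return value


retriever = AsyncMock(FakeModel)
stored = validate_field(retriever)
print(stored is retriever)  # False
print(repr(stored))  # <MagicMock name='mock._copy_and_set_values()' id='...'>


async def use_stored():
    # stored is a plain MagicMock, so its attributes are plain MagicMocks too.
    return await stored.aget_relevant_documents("foo")


message = None
try:
    asyncio.run(use_stored())
except TypeError as exc:
    message = str(exc)
print(message)  # a message ending in "can't be used in 'await' expression"
```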

Any idea where I'm going wrong?
