Filter langchain vector database using as_retriever search_kwargs parameter

How can I filter a LangChain vector database using the search_kwargs parameter of the as_retriever function?

Here is an example of what I would like to do:

# Let's say I have the following vector database
db = {'3c3bc745': Document(page_content="This is my text A", metadata={'Field_1': 'S', 'Field_2': 'R'}),
      '14f84778': Document(page_content="This is my text B", metadata={'Field_1': 'S', 'Field_2': 'V'}),
      'bd0022c9-449b': Document(page_content="This is my text C", metadata={'Field_1': 'Z', 'Field_2': 'V'})}


# Filter the vector database
retriever = db.as_retriever(search_kwargs={'filter': dict(Field_1='Z'), 'k': 1})

# Create the conversational chain
chain = ConversationalRetrievalChain.from_llm(llm=ChatOpenAI(temperature=0.0,
                                                         model_name='gpt-3.5-turbo',
                                                         deployment_id="chat"),
                                                        retriever=retriever)

chat_history = []
prompt = "Which sentences do you have ?"

# Expected: only "This is my text C", but I also get the two other page_content elements
chain({"question": prompt, "chat_history": chat_history})

There are 2 answers

Samuel

If you are using DataStax Astra/Cassandra as the vector store, it would be something like:

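import os

import cassio
from langchain.chains import RetrievalQA
from langchain.vectorstores.cassandra import Cassandra

# Connect to Astra DB using credentials taken from environment variables
cassio.init(token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], database_id=os.environ["ASTRA_DB_ID"])

# embedding_generator, file, llm, PROMPT and QUERY are assumed to be defined elsewhere
table_name = 'vs_investment_kb'
keyspace = 'demo'

CassVectorStore = Cassandra(
    session=cassio.config.resolve_session(),
    keyspace=keyspace,
    table_name=table_name,
    embedding=embedding_generator
)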

retrieverSim = CassVectorStore.as_retriever(
    search_type='similarity',
    search_kwargs={
        'k': 4,
        'filter': {"source": file}
    },
)

# Create a "RetrievalQA" chain
chainSim = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retrieverSim,
    chain_type_kwargs={
        'prompt': PROMPT,
        'document_variable_name': 'summaries'
    }
)
# Run it and print results
responseSim = chainSim.run(QUERY)
print(responseSim)

Full example here: https://github.com/smatiolids/astra-agent-memory/blob/main/Explicando%20Retrieval%20Augmented%20Generation.ipynb

Martin D
retriever = db.as_retriever(search_kwargs={'filter': {'Field_1': 'Z'}})
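
To make the expected behaviour of 'filter' and 'k' concrete, here is a minimal pure-Python sketch of what the question is asking for. It does not use LangChain: Doc, matches and retrieve are hypothetical stand-ins written only for this illustration, and a real vector store would also rank the candidates by embedding similarity. Whether a particular store honours the 'filter' key may depend on that store's implementation, so treat this as a description of the intended semantics rather than of any library's internals.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    # Hypothetical stand-in for a document with text and metadata
    # (not the real LangChain Document class).
    page_content: str
    metadata: dict = field(default_factory=dict)


def matches(metadata, flt):
    # A document passes when every key in the filter is present in its
    # metadata with an equal value.
    return all(key in metadata and metadata[key] == value
               for key, value in flt.items())


def retrieve(docs, flt=None, k=4):
    # Hypothetical helper: apply the metadata filter first, then keep at
    # most k documents. Similarity ranking is deliberately left out.
    candidates = [d for d in docs if matches(d.metadata, flt or {})]
    return candidates[:k]


docs = [
    Doc("This is my text A", {"Field_1": "S", "Field_2": "R"}),
    Doc("This is my text B", {"Field_1": "S", "Field_2": "V"}),
    Doc("This is my text C", {"Field_1": "Z", "Field_2": "V"}),
]

# Same filter and k as in the question
result = retrieve(docs, flt={"Field_1": "Z"}, k=1)
print([d.page_content for d in result])
```

With the three documents from the question, only the one whose Field_1 equals 'Z' survives the filter. If your retriever still returns the other documents, it is worth checking whether the vector store you are using actually applies the filter.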