How do you apply session state to Gradio's ChatInterface?

612 views Asked by At

I have created a chatbot with the Gradio's gr.ChatInterface() and Langchain's ConversationalRetrievalChain with chat history. Once uploaded on Huggingface Spaces, I noticed the chat history was being shared across users. I have tried opening the link to my model in Huggingface Spaces on different browsers/ devices, and the conversation history is still retained.

I would like the chat history to be different for every user and not to get jumbled between different users. How can I implement Gradio's ChatInterface() with session state, where the chat history is cleared after each session and is different for every user?

My code is here:

import os
from typing import Optional, Tuple

import gradio as gr
from langchain.chains import ConversationChain, ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.llms import OpenAI
from langchain.schema import AIMessage, HumanMessage

def load_chain(llm_name="gpt-3.5-turbo"):
    """Logic for loading the chain you want to use should go here."""
    
    # define embedding
    embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': False})
    
    # create vector database from data
    persist_directory = 'docs/chroma/'
    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
    
    # Wrap our vectorstore
    llm = OpenAI(temperature=0)
    compressor = LLMChainExtractor.from_llm(llm)
    
    # define retriever
    compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=vectordb.as_retriever(search_type = "mmr")
)    
    # Build prompt
    template = """
    Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. 
    If there are any assumptions or requirements for the answer to apply, please include them in your response. 
    {context}
    Question: {question}
    Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template,)
    
    # define memory
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    
    # create a chatbot chain
    chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model_name=llm_name, temperature=0), 
        memory=memory, 
        retriever=compression_retriever, 
        combine_docs_chain_kwargs={"prompt": QA_CHAIN_PROMPT}
    )
    return chain

chain = load_chain()

def predict(message, history):
    history_langchain_format = []
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))
    gpt_response = chain({"question": message})
    return gpt_response['answer']
    
block = gr.Blocks()

with block:
    
    chatbot = gr.ChatInterface(
    fn=predict,
    title="Chatbot")

block.launch()
0

There are 0 answers