How to handle a memory-intensive task causing WorkerLostError with Celery and HuggingFaceEmbeddings?


I'm trying to use Celery to handle the heavy task of creating a new Qdrant collection every time a new model instance is created: I need to extract the content of the file, create embeddings, and store them in the Qdrant DB as a collection. The problem is that I get the following error when I call embeddings.embed_query with HuggingFaceEmbeddings inside the Celery task.

celery-dev-1  | [2024-03-27 10:18:27,451: INFO/ForkPoolWorker-19] Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
celery-dev-1  | [2024-03-27 10:18:35,856: ERROR/MainProcess] Process 'ForkPoolWorker-19' pid:115 exited with 'signal 11 (SIGSEGV)'
celery-dev-1  | [2024-03-27 10:18:35,868: ERROR/MainProcess] Task handler raised error: WorkerLostError('Worker exited prematurely: signal 11 (SIGSEGV) Job: 3.')
celery-dev-1  | Traceback (most recent call last):
celery-dev-1  |   File "/usr/local/lib/python3.10/site-packages/billiard/pool.py", line 1264, in mark_as_worker_lost
celery-dev-1  |     raise WorkerLostError(
celery-dev-1  | billiard.einfo.ExceptionWithTraceback:
celery-dev-1  | """
celery-dev-1  | Traceback (most recent call last):
celery-dev-1  |   File "/usr/local/lib/python3.10/site-packages/billiard/pool.py", line 1264, in mark_as_worker_lost
celery-dev-1  |     raise WorkerLostError(
celery-dev-1  | billiard.exceptions.WorkerLostError: Worker exited prematurely: signal 11 (SIGSEGV) Job: 3.
celery-dev-1  | """

Here is the Knowledge model from which the task is called:

class Knowledge(Common):
    name = models.CharField(max_length=255, blank=True, null=True)
    file = models.FileField(upload_to=knowledge_path, storage=PublicMediaStorage())
    qd_knowledge_id = models.CharField(max_length=255, blank=True, null=True)
    is_public = models.BooleanField(default=False)

    def save(self, *args, **kwargs):
        # On first save, enqueue the Celery task that builds the Qdrant collection
        if self.pk is None:
            collection_name = f"{self.name}-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
            process_files_and_upload_to_qdrant.delay(self.file.name, collection_name)
            self.qd_knowledge_id = collection_name
        super().save(*args, **kwargs)
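So simply creating a Knowledge row is what kicks off the task. Hypothetical usage (the field values and the uploaded_pdf variable are just examples):

# Hypothetical usage: the first save of a new Knowledge row enqueues the task
knowledge = Knowledge(name="manual", file=uploaded_pdf, is_public=False)
knowledge.save()  # calls process_files_and_upload_to_qdrant.delay(...) then saves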

Here is the task, along with the functions it calls:

@shared_task
def process_files_and_upload_to_qdrant(file_name, collection_name):
    # default_storage.open returns a File object for the uploaded PDF
    file_path = default_storage.open(file_name)
    result = process_file(file_path, collection_name)
    return result

def process_file(file: InMemoryUploadedFile, collection_name):
    text = read_data_from_pdf(file)
    chunks = get_text_chunks(text)
    embeddings = get_embeddings(chunks)
    client.create_collection(
            collection_name=collection_name,
            vectors_config=qdrant_models.VectorParams(
                size=768, distance=qdrant_models.Distance.COSINE
            ),
        )
    client.upsert(collection_name=collection_name, wait=True, points=embeddings)
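
The client referenced above is a module-level Qdrant client; it isn't shown in the snippets, but it looks roughly like this (host/port are placeholders for my docker-compose setup):

from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models

# Module-level client used by process_file; host/port are placeholders
client = QdrantClient(host="qdrant", port=6333)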


def read_data_from_pdf(file: InMemoryUploadedFile):
    text = ""

    pdf_reader = PdfReader(file)

    for page in pdf_reader.pages:
        text += page.extract_text()

    return text


def get_text_chunks(texts: str):
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
    )
    chunks = text_splitter.split_text(texts)
    return chunks


def get_embeddings(text_chunks):
    import uuid
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from qdrant_client.http.models import PointStruct

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    points = []

    for chunk in text_chunks:
        embedding = embeddings.embed_query(chunk)  # <-- the error occurs here
        point_id = str(uuid.uuid4())
        points.append(
            PointStruct(id=point_id, vector=embedding, payload={"text": chunk})
        )

    return points
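
For what it's worth, embed_query returns a plain Python list of floats, and all-mpnet-base-v2 produces 768-dimensional vectors, which is why the collection is created with size=768 above. A quick illustration (not part of my code):

from langchain_community.embeddings import HuggingFaceEmbeddings

# Illustration only: the vector length matches the collection's size=768
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vec = embeddings.embed_query("hello world")
assert isinstance(vec, list) and len(vec) == 768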

How do I approach this? Since the model is created through a many-to-many field, the response takes a long time, which is why I'm trying to move this work into a Celery task. (Some delay when storing in Qdrant is acceptable; it just shouldn't affect the API response time.) The API works fine when I do it without Celery, but it's super slow.
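
For reference, the non-Celery version that works (but blocks the request) just calls process_file directly in save(), roughly like this:

    # The synchronous version that works, but makes the API response very slow
    def save(self, *args, **kwargs):
        if self.pk is None:
            collection_name = f"{self.name}-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
            process_file(self.file, collection_name)  # runs inside the request
            self.qd_knowledge_id = collection_name
        super().save(*args, **kwargs)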

I've tried splitting the work into multiple smaller Celery tasks, but I can't pass the embeddings or other non-JSON-serializable data between the tasks. I don't know how to approach this; a rough sketch of that attempt is below.
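
Sketch of the split-task attempt (task names are just for illustration): the PointStruct list returned by the first task can't be serialized with Celery's default JSON serializer, so chaining these tasks fails.

@shared_task
def embed_chunks_task(chunks):
    return get_embeddings(chunks)  # list of PointStruct objects, not JSON serializable

@shared_task
def upload_points_task(points, collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=qdrant_models.VectorParams(
            size=768, distance=qdrant_models.Distance.COSINE
        ),
    )
    client.upsert(collection_name=collection_name, wait=True, points=points)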
