indexing does not speed up retrieval of numpy array from sqlite3


I am trying to store embeddings as numpy arrays in sqlite3, which stores them as BLOBs. I am handling approximately 30 million rows. I am able to store the data, retrieve it as a BLOB, and de-serialize it back into a numpy array, but retrieving embeddings by id is slow, and even after indexing the speed doesn't go up. I retrieve the data in batches so memory gets released. It took 98.21 seconds to retrieve 10,000 data points from the database, and 11 minutes to retrieve 0.1 million embeddings. As far as I know, an index should reduce query execution time. I don't know where I am going wrong; any suggestion will be much appreciated.
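
A quick sanity check worth running (a minimal sketch, not from the original code): ask SQLite for its query plan against the data_array table defined below. An indexed lookup should report SEARCH ... USING INDEX rather than SCAN; the id 42 here is just a placeholder value.

import sqlite3

conn = sqlite3.connect('my_db.db')
cursor = conn.cursor()

# The UNIQUE constraint on custom_id builds an implicit index; this shows
# whether SQLite plans to use it for a point lookup
cursor.execute(
    "EXPLAIN QUERY PLAN SELECT array_data FROM data_array WHERE custom_id = ?",
    (42,),
)
print(cursor.fetchall())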


import sqlite3
import numpy as np
import io
import json            # used by doc_loader below
from tqdm import tqdm  # progress bar for the insertion loop

# Establish connection with SQLite database
conn = sqlite3.connect('my_db.db')
cursor = conn.cursor()

# Create a table to store the array data with an ID.
# The UNIQUE constraint on custom_id creates an implicit index,
# so lookups on custom_id are already indexed.
cursor.execute('''CREATE TABLE IF NOT EXISTS data_array(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    custom_id INTEGER UNIQUE,
    array_data BLOB);''')


def doc_loader(file_id):
    """
    Load the embedding and ids.
    
    """
    embeds = np.load(f"../downloaded_files/embeds_chunk_{file_id}.npy")
    ids = json.load(open(f"../downloaded_files/ids_chunk_{file_id}.json"))

    return (ids,embeds)

def byte_converter(embeddings):
    """
    Convert numpy array into byte code to store the embeddings.
    
    """
    buffer = io.BytesIO()
    np.save(buffer, embeddings)
    return buffer.getvalue()
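
# Illustrative round-trip check (an added example, not in the original code):
# serializing with byte_converter and loading the bytes back should reproduce
# the array exactly; the 768-dim float32 vector is an arbitrary example size.
_example = np.random.rand(768).astype(np.float32)
assert np.array_equal(_example, np.load(io.BytesIO(byte_converter(_example))))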

def sql_format_converter(ids,embeddings):
    """
    Formatting the data for SQL.
    """
    embedding_blob_array = [byte_converter(embed) for embed in embeddings]
    embedding_ids = [int(emb_id) for emb_id in ids]

    assert len(embedding_ids) == len(embedding_blob_array), "len mismatch for embeddings"

    return (embedding_ids,embedding_blob_array)  

def insert_to_database(ids, byte_code):
    """ Insert embeddings into the database """

    # Pair each custom ID with its serialized embedding
    data_to_insert = [(ids[i], sqlite3.Binary(data_bytes)) for i, data_bytes in enumerate(byte_code)]

    ## inserting the data; catch only database errors rather than BaseException
    try:
        cursor.executemany("INSERT OR REPLACE INTO data_array (custom_id, array_data) VALUES (?, ?)", data_to_insert)
    except sqlite3.Error as e:
        print(e)

## Final loop to store the data into the database (file_ids: the chunk file IDs to load)
for file_id in tqdm(file_ids):

    docs = doc_loader(file_id) ## loading docs

    byte_codes = sql_format_converter(*docs) ## converting to byte codes

    insert_to_database(*byte_codes)


conn.commit()
conn.close()

The whole data-insertion loop completed in 50 minutes.
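
An aside on the insertion side (a sketch, not from the original post): WAL journaling plus one commit per file chunk is a common way to speed up bulk loads like this, at the cost of slightly weaker durability on power loss with synchronous=NORMAL. This variant reuses conn, cursor, file_ids, and the helper functions from the code above.

## Optional bulk-load settings (sketch; reuses the connection and helpers above)
conn.execute("PRAGMA journal_mode=WAL;")    # write-ahead log: faster bulk writes
conn.execute("PRAGMA synchronous=NORMAL;")  # fewer fsyncs than FULL, slightly less durable

for file_id in tqdm(file_ids):
    ids, embeds = doc_loader(file_id)
    insert_to_database(*sql_format_converter(ids, embeds))
    conn.commit()  # bounded transaction per chunk instead of one giant commit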

Data retrieval part


## Generate 10,000 random IDs between 1 and 30 million
import random
random_numbers = [random.randrange(1, 30000001) for _ in range(10000)]

import io
import numpy as np

def retrieve_embeddings_in_batch(doc_list, cursor, batch_size=1000):
    """
    Retrieve multiple embeddings using batch fetching with streaming.
    """
    embeddings_list = []

    # Split doc_list into chunks to fetch in batches
    for i in range(0, len(doc_list), batch_size):
        chunk = doc_list[i:i + batch_size]
        # One placeholder per ID in this chunk; sizing this with batch_size
        # would break on a final partial chunk
        placeholders = ','.join(['?'] * len(chunk))
        query = f"SELECT custom_id, array_data FROM data_array WHERE custom_id IN ({placeholders})"
        cursor.execute(query, chunk)

        # Fetch results in batches and process them
        while True:
            results = cursor.fetchmany(batch_size)
            if not results:
                break

            for result in results:
                retrieved_id = result[0]
                array_blob = result[1]

                buffer = io.BytesIO(array_blob)
                retrieved_array = np.load(buffer)  # np.load already returns an ndarray

                embeddings_list.append({'id': retrieved_id, 'embedding': retrieved_array})

    return embeddings_list  

import time

## measuring time
start_time = time.time()
embeddings = retrieve_embeddings_in_batch(random_numbers, cursor, batch_size=1000)
end_time = time.time()


execution_time = end_time - start_time
print(execution_time) ## 98.21 sec
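
If the query plan confirms indexed lookups, most of the remaining time is likely random page I/O on a large BLOB table. One thing worth trying (a sketch under that assumption, not from the original post) is sorting the requested IDs first, so each IN (...) batch touches table pages in roughly ascending order:

import io
import numpy as np

def retrieve_embeddings_sorted(doc_list, cursor, batch_size=1000):
    """Same batched lookup as above, but over sorted IDs; ascending
    custom_id order can reduce random I/O on a large table. Measure on
    real data -- this is a sketch, not a guaranteed win."""
    embeddings = {}
    sorted_ids = sorted(doc_list)
    for i in range(0, len(sorted_ids), batch_size):
        chunk = sorted_ids[i:i + batch_size]
        placeholders = ','.join(['?'] * len(chunk))
        cursor.execute(
            f"SELECT custom_id, array_data FROM data_array WHERE custom_id IN ({placeholders})",
            chunk,
        )
        for retrieved_id, blob in cursor.fetchall():
            embeddings[retrieved_id] = np.load(io.BytesIO(blob))
    return embeddings

Since the results are keyed by id, the original request order can be restored afterwards if needed.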

