I'm working on a function that does some processing on large input data. Since I can't fit all the data into memory at once (the dot product would be a 117703x200000 matrix), I'm dividing it into chunks and computing it in parts.
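
For scale, here is a back-of-the-envelope sketch of the sizes involved. It assumes 8-byte (float64) elements, which is what np.random.random returns, and the default frac=50 chunks. The constant names are illustrative.

```python
import math

N_VAL, N_GAL = 117703, 200000   # query rows, gallery rows (from the question)
BYTES = 8                       # float64, the dtype np.random.random returns
FRAC = 50                       # number of chunks, the default in the question

full = N_VAL * N_GAL * BYTES            # whole similarity matrix at once
chunk_rows = math.ceil(N_VAL / FRAC)    # rows handled per loop iteration
per_chunk = chunk_rows * N_GAL * BYTES  # one chunk of the similarity matrix
output = N_VAL * 5 * BYTES              # final N_val x 5 result, assuming 8-byte elements

print(f"full matrix: {full / 1e9:.1f} GB")
print(f"one chunk ({chunk_rows} rows): {per_chunk / 1e9:.2f} GB")
print(f"output: {output / 1e6:.1f} MB")
```

Each iteration allocates several arrays of the per-chunk size (dist_i, index_i, the repeated gal_cls, pred_i). The per-iteration peak is therefore a multiple of the single-chunk figure, even though the final output is only a few megabytes.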

The output keeps only the first 5 elements of each row (after sorting), so it has shape 117703x5, which easily fits in memory. However, as the loop runs, my memory consumption keeps increasing until I get a memory error. Any ideas why? Here is the code:

import gc

import numpy as np
import pandas as pd
from tqdm import tqdm

def process_predictions_proto(frac=50):
    # Simulate some inputs
    query_embeddings = np.random.random((117703, 512))
    proto_feat = np.random.random((200000, 512))
    gal_cls = np.arange(200000)

    N_val = query_embeddings.shape[0]
    pred = []

    for i in tqdm(range(frac)):
        start = i * int(np.ceil(N_val / frac))
        stop = (i + 1) * int(np.ceil(N_val / frac))
        val_i = query_embeddings[start:stop, :]
        # Compute distances
        dist_i = np.dot(val_i, proto_feat.transpose())
        # Sort each row by descending similarity ([:, ::-1] reverses columns, not rows)
        index_i = np.argsort(dist_i, axis=1)[:, ::-1]
        dist_i = np.take_along_axis(dist_i, index_i, axis=1)
        # Convert distances to class_ids
        pred_i = np.take_along_axis(
            np.repeat(gal_cls[np.newaxis, :], index_i.shape[0], axis=0),
            index_i, axis=1)
        # Use pd.unique to remove copies of the same class_id and
        # get 5 most similar ids
        pred_i = [pd.unique(pi)[:5] for pi in pred_i]
        # Append to list
        pred.append(pred_i)
        # Free memory
        gc.collect()
    pred = np.concatenate(pred, axis=0)  # N_val x 5 (the last chunk can be shorter)
    return pred
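
The chunk boundaries can be checked without touching the big arrays. This sketch reuses only the sizes from the question. With frac=50, every chunk has 2355 rows except the last, which has 2308. np.stack requires all inputs to have the same shape, so it cannot join the per-chunk results, while np.concatenate along axis 0 gives one N_val x 5 array. The zero-filled `parts` arrays are hypothetical stand-ins for the real predictions.

```python
import math

import numpy as np

N_val, frac = 117703, 50
chunk = math.ceil(N_val / frac)         # rows per chunk, same as int(np.ceil(N_val / frac))

# Rows covered by each chunk; slicing past the end is simply clipped.
sizes = [len(range(N_val)[i * chunk:(i + 1) * chunk]) for i in range(frac)]
print(chunk, sizes[-1], sum(sizes))     # 2355 2308 117703

# Zero-filled stand-ins for the per-chunk results (rows x 5).
parts = [np.zeros((n, 5), dtype=np.int64) for n in sizes]

try:
    np.stack(parts, 0)                  # needs every input to have the same shape
    stack_failed = False
except ValueError:
    stack_failed = True
print(stack_failed)                     # True

result = np.concatenate(parts, axis=0)  # joins along the existing row axis
print(result.shape)                     # (117703, 5)
```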

Answer by Barmar

Delete all the temporary variables before calling gc.collect(), so that the data will become garbage immediately.

del start, stop, val_i, dist_i, index_i, pred_i
gc.collect()
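
To see the effect of del on a NumPy array directly, here is a small sketch assuming CPython (which frees objects by reference counting). A weak reference observes the array without keeping it alive. gc.collect() cannot free the array while a name is still bound to it, but the array goes away as soon as that name is deleted. The 1000x1000 array is an arbitrary stand-in for one iteration's data.

```python
import gc
import weakref

import numpy as np

chunk = np.zeros((1000, 1000))  # stand-in for one iteration's dist_i
ref = weakref.ref(chunk)        # observe the array without keeping it alive

gc.collect()
alive_while_bound = ref() is not None  # still reachable through `chunk`

del chunk
freed_after_del = ref() is None        # released once the last reference is gone

print(alive_while_bound, freed_after_del)  # True True
```

When no reference cycles are involved, the del on its own releases the memory; gc.collect() only matters for objects kept alive by cycles.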

In your code, when you call gc.collect() the first time none of the data is garbage, because it can still be referenced from all the variables. The data from the first iteration won't be collected until the end of the second iteration; during each iteration after the first, you'll have two chunks of data in memory (the current iteration and the previous iteration). So you're using twice as much memory as you need (I assume there are references between some of the objects, so automatic GC is not cleaning up objects as the variables are reassigned during the loop).
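
Separately from when gc.collect() runs, the question's line `pred_i = [pd.unique(pi)[:5] for pi in pred_i]` may also be worth checking. Basic slicing of a NumPy array returns a view that keeps a reference to the array it came from. Each stored 5-element result could therefore keep an entire row's worth of unique ids alive: 200000 of them in the simulated data. Those results accumulate in `pred` across iterations, so the retained memory would grow with every chunk. The sketch below looks at a single simulated row and shows that a slice shares memory with the full pd.unique result, while `.copy()` produces an independent array.

```python
import numpy as np
import pandas as pd

row = np.arange(200000)      # one row of pred_i: 200000 distinct class ids
uniq = pd.unique(row)        # pd.unique returns a new array of all unique ids

top5 = uniq[:5]              # basic slicing returns a view, not a copy
print(np.shares_memory(top5, uniq), top5.base is not None)  # True True
print(uniq.nbytes)           # bytes that stay alive as long as top5 does

top5_copy = uniq[:5].copy()  # independent 5-element array that owns its data
print(np.shares_memory(top5_copy, uniq), top5_copy.base is None)  # False True
```

In the question's loop this would mean keeping `pd.unique(pi)[:5].copy()` (or `np.array(pd.unique(pi)[:5])`) so that only the five ids survive each iteration.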