I am working on a function that does some processing on big input data. Since I can't fit all of the data into memory at once (the dot product alone produces a 117703x200000 matrix), I'm dividing the input into chunks and computing the result in parts.
The output only keeps the first 5 elements of each row (after sorting), so it has shape 117703x5, which easily fits in memory. However, for some reason my memory consumption keeps increasing as the loop runs, until I get a memory error. Any ideas why? Here is the code:

```python
import gc

import numpy as np
import pandas as pd
from tqdm import tqdm


def process_predictions_proto(frac=50):
    # Simulate some inputs
    query_embeddings = np.random.random((117703, 512))
    proto_feat = np.random.random((200000, 512))
    gal_cls = np.arange(200000)

    N_val = query_embeddings.shape[0]
    pred = []
    for i in tqdm(range(frac)):
        start = i * int(np.ceil(N_val / frac))
        stop = (i + 1) * int(np.ceil(N_val / frac))
        val_i = query_embeddings[start:stop, :]
        # Compute distances
        dist_i = np.dot(val_i, proto_feat.transpose())
        # Sort each row in descending order
        index_i = np.argsort(dist_i, axis=1)[:, ::-1]
        dist_i = np.take_along_axis(dist_i, index_i, axis=1)
        # Convert distances to class_ids
        pred_i = np.take_along_axis(
            np.repeat(gal_cls[np.newaxis, :], index_i.shape[0], axis=0),
            index_i, axis=1)
        # Use pd.unique to remove copies of the same class_id and
        # get the 5 most similar ids
        pred_i = [pd.unique(pi)[:5] for pi in pred_i]
        # Append to list
        pred.append(pred_i)
        # Free memory
        gc.collect()

    # Chunks can have different sizes, so concatenate rather than stack
    pred = np.concatenate(pred, axis=0)  # N_val x 5
    return pred
```
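
For reference, here is a minimal sketch of how the sort/lookup step inside the loop could be written without the `np.repeat`, which materializes a full chunk x 200000 copy of `gal_cls` on every iteration. The helper name `top5_for_chunk` is mine, and it skips the `pd.unique` dedup since the simulated `gal_cls` has no duplicate ids; I'm not claiming this is the cause of the growth, just a leaner way to express the same step:

```python
import numpy as np

def top5_for_chunk(val_i, proto_feat, gal_cls, k=5):
    # Distances for this chunk: (chunk_size, 200000)
    dist_i = val_i @ proto_feat.T
    # Unordered indices of the k largest distances per row
    part = np.argpartition(dist_i, -k, axis=1)[:, -k:]
    # Order those k indices by descending distance
    order = np.argsort(np.take_along_axis(dist_i, part, axis=1), axis=1)[:, ::-1]
    top_idx = np.take_along_axis(part, order, axis=1)
    # Fancy-index gal_cls directly instead of repeating it per row
    return gal_cls[top_idx]  # (chunk_size, k)
```

In the loop above this would be `pred_i = top5_for_chunk(val_i, proto_feat, gal_cls)`; the `pd.unique` step would still be needed if the real `gal_cls` contains repeated class ids.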