Loading large sets of data for TensorFlow deep learning

43 views Asked by At

I'm loading data consisting of thousands of MRI images. I'm doing it like this using nibabel to obtain the 3D data arrays from the MRI files:

def get_voxels(path):
    img = nib.load(path)
    data = img.get_fdata()

    return data.copy()


df = pd.read_csv("/home/paths_updated_shuffled_4.csv")
df = df.reset_index()

labels = []
images = []
for index, row in df.iterrows():
    images.append(get_voxels(row['path']))
    labels.append(row['pass'])
labels = np.array(labels)
images = np.array(images)

n = len(df.index)
train_n = int(0.8 * n)
train_images = images[:train_n]
train_labels = labels[:train_n]
validation_n = (n - train_n) // 2
validation_end = train_n + validation_n
validation_images, validation_labels = images[train_n:validation_end], labels[train_n:validation_end]
test_images = images[validation_end:]
test_labels = labels[validation_end:]

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

As you can see, I'm using tf.data.Dataset.from_tensor_slices. However, I'm running out of memory because of the large number of large files.

Is there a better way to do this in TensorFlow or Keras.

1

There are 1 answers

0
Paul Reiners On BEST ANSWER

Do as described in 3D image classification from CT scans by Hasib Zunair.

import nibabel as nib
import pandas as pd
import numpy as np

def process_scan(path):
    """Read and resize volume"""
    # Read scan
    volume = read_nifti_file(path)
    # Normalize
    volume = normalize(volume)
    # Resize width, height and depth
    volume = resize_volume(volume)
    return volume


df = pd.read_csv("/home/paths_updated_shuffled_4.csv")
n = len(df.index)
passing_rows = df.loc[df['pass'] == 1]
normal_scan_paths = passing_rows['path'].tolist()
failing_rows = df.loc[df['pass'] == 0]
abnormal_scan_paths = failing_rows['path'].tolist()

print("Passing MRI scans: " + str(len(normal_scan_paths)))
print("Failing MRI scans: " + str(len(abnormal_scan_paths)))

# Loading data and preprocessing
# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])