DL4J/ND4J Model Fitting Error: Rank of Input Data Drops to 0 After Slicing INDArray for Training Subset


I'm using ND4J for a machine learning task and have run into unexpected behavior with my 4D INDArray. During model fitting, an error about the rank of the input data is thrown.

  • My initial INDArray, allData, has a shape of [4392, 2, 21, 12].
  • A control loop iterates over allData in batches of size 24 and confirms that all batches have a consistent rank of 4.
  • After extracting a training subset y with the code below, the batches taken from y starting at index 3072 have rank 0, which is unexpected.
INDArray y = allData.get(NDArrayIndex.interval(numberOfSkipHours, lenTotal),
                         NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
  • numberOfSkipHours is set to 672, and lenTotal is 4392, so y is expected to be [3720, 2, 21, 12].
  • The issue is not observed in allData before extraction, and no rank problems are detected in the control loop for allData.
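As a sanity check on the numbers above, the shape arithmetic can be worked out in plain Java without calling ND4J. This is a sketch that only assumes `NDArrayIndex.interval(begin, end)` excludes `end`, as the slicing code above relies on; the class and method names are hypothetical. Under that assumption y has 3720 rows (155 batches of 24 with no remainder), and the batch starting at y index 3072 reads rows 3744 to 3767 of allData.

```java
import java.util.Arrays;

// Hypothetical helper: pure shape arithmetic, no ND4J calls.
class SliceShapeCheck {

    // Expected shape of x.get(interval(begin, end), all(), all(), all()),
    // assuming the interval end is exclusive. Only the first dimension shrinks.
    static long[] slicedShape(long[] shape, long begin, long end) {
        if (begin < 0 || end > shape[0] || begin > end) {
            throw new IllegalArgumentException("interval out of range");
        }
        long[] out = shape.clone();
        out[0] = end - begin;
        return out;
    }

    // Row of the original array that corresponds to row yIndex of the slice.
    static long originalRow(long begin, long yIndex) {
        return begin + yIndex;
    }

    public static void main(String[] args) {
        long[] allDataShape = {4392, 2, 21, 12};
        long numberOfSkipHours = 672;
        long lenTotal = 4392;

        long[] yShape = slicedShape(allDataShape, numberOfSkipHours, lenTotal);
        System.out.println("y shape: " + Arrays.toString(yShape));
        System.out.println("rank: " + yShape.length);
        System.out.println("batches of 24: " + yShape[0] / 24 + ", remainder " + yShape[0] % 24);
        System.out.println("y[3072] is allData[" + originalRow(numberOfSkipHours, 3072) + "]");
    }
}
```

Because the slice keeps all four dimensions, nothing in the arithmetic itself predicts a rank change at 3072.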

This drop in rank occurs only after creating the subset y, and I'm at a loss as to why. Here are the outputs of the rank control loop for allData before and after normalization, and for y after extraction: [screenshot: output for rank control loop]

Why does this rank change occur at index 3072 for y, and how might it be related to the model fitting error?

I attempted to create a training subset y from a larger INDArray allData by slicing it along the first dimension to exclude a certain number of initial hours (numberOfSkipHours). I expected the resulting y to have the same rank (4) as allData because I only reduced the size of the first dimension, maintaining the 4D structure.

INDArray y = allData.get(NDArrayIndex.interval(numberOfSkipHours, allData.size(0)),
                         NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());

After slicing, I iterated over y in batches to check the rank, expecting every batch to keep rank 4. Instead, starting at index 3072 the batches had rank 0.

int batchSize = 24; // divides y.size(0) = 3720 evenly (155 batches)
for (int i = 0; i < y.size(0); i += batchSize) {
    // Take rows [i, i + batchSize) of y, keeping the other three dimensions whole
    INDArray batch = y.get(NDArrayIndex.interval(i, i + batchSize),
                           NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    if (batch.rank() != 4) {
        // This is where I encountered the unexpected result
        System.out.println("Unexpected rank " + batch.rank() + " detected at batch starting index: " + i);
        break;
    }
}
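One way to rule out the loop bounds as the cause is to enumerate the row ranges the loop requests, independent of ND4J. The sketch below is plain Java with hypothetical class and method names; it mirrors the loop's half-open `[i, i + batchSize)` indices, and its `Math.min` clip is an addition not present in the loop above that only matters when the batch size does not divide the row count. For 3720 rows and a batch size of 24 it produces 155 full batches, and the batch at 3072 is `[3072, 3096)`, well inside y, so the requested indices themselves are valid.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists the [start, end) ranges a batching loop would request.
class BatchBounds {

    // Half-open ranges covering n rows in steps of batchSize; the last one is clipped to n.
    static List<long[]> bounds(long n, int batchSize) {
        List<long[]> out = new ArrayList<>();
        for (long start = 0; start < n; start += batchSize) {
            out.add(new long[] {start, Math.min(start + batchSize, n)});
        }
        return out;
    }

    public static void main(String[] args) {
        List<long[]> b = bounds(3720, 24);
        System.out.println("batches: " + b.size());
        for (long[] r : b) {
            if (r[1] - r[0] != 24) {
                System.out.println("short batch at " + r[0]);
            }
        }
        long[] at3072 = b.get(3072 / 24);
        System.out.println("batch at 3072: [" + at3072[0] + ", " + at3072[1] + ")");
    }
}
```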

Given these controlled conditions I expected no rank issues, but the loop reported a batch with rank 0. That makes no sense for a slice taken only along the first dimension, and it suggests a problem with either the slicing or the data itself. This matters because it causes errors when fitting the model with ND4J.
