I'm using ND4J for a machine learning task and have encountered unexpected behavior with my 4D INDArray. During model fitting, an error related to the input data rank is thrown.
- My initial INDArray, allData, has a shape of [4392, 2, 21, 12].
- A control loop iterates over allData in batches of size 24 and confirms that all batches have a consistent rank of 4.
- After extracting a subset y for training using the code below, batches extracted from y starting at index 3072 during the iteration have a rank of 0, which is unexpected.
INDArray y = allData.get(NDArrayIndex.interval(numberOfSkipHours, lenTotal),
        NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
- numberOfSkipHours is set to 672 and lenTotal is 4392, so y is expected to be [3720, 2, 21, 12].
- The issue is not observed in allData before extraction, and no rank problems are detected in the control loop for allData.
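As a sanity check on the expected shape, the arithmetic can be written out in plain Java. This is only a sketch of the bookkeeping: the class ExpectedShape and the method sliceFirstDim are hypothetical names, not ND4J API, and the code assumes that slicing along the first dimension uses a half-open interval [from, to) and leaves the other dimensions untouched.

```java
import java.util.Arrays;

// Hypothetical helper, not part of ND4J: computes the shape expected from
// slicing only the first dimension with the half-open interval [from, to).
class ExpectedShape {
    static long[] sliceFirstDim(long[] shape, long from, long to) {
        if (from < 0 || to > shape[0] || from > to) {
            throw new IllegalArgumentException("interval out of range");
        }
        long[] result = shape.clone();
        result[0] = to - from; // only the first dimension shrinks
        return result;
    }

    public static void main(String[] args) {
        long[] allDataShape = {4392, 2, 21, 12};
        // numberOfSkipHours = 672, lenTotal = 4392
        long[] yShape = sliceFirstDim(allDataShape, 672, 4392);
        System.out.println(Arrays.toString(yShape)); // prints [3720, 2, 21, 12]
    }
}
```

The result still has four entries, so a rank of 4 is what the slice should report if it behaves as assumed.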
This drop in rank occurs only after creating the subset y, and I'm at a loss as to why. Here are the outputs for allData before and after normalization, and for y after extraction.
[Image: output of the rank control loop]
Why does this rank change occur at index 3072 for y, and how might it be related to the model fitting error?
I attempted to create a training subset y from a larger INDArray, allData, by slicing it along the first dimension to exclude a certain number of initial hours (numberOfSkipHours). I expected the resulting y to have the same rank (4) as allData because I only reduced the size of the first dimension, maintaining the 4D structure.
INDArray y = allData.get(NDArrayIndex.interval(numberOfSkipHours, allData.size(0)),
        NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
After slicing, I iterated over y in batches to check the rank. I expected each batch to keep a rank of 4 throughout the iteration. However, the rank of the batches dropped to 0 starting at index 3072.
int batchSize = 24; // Known to divide evenly into the total number of samples
for (int i = 0; i < y.size(0); i += batchSize) {
    INDArray batch = y.get(NDArrayIndex.interval(i, i + batchSize),
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    if (batch.rank() != 4) {
        // This is where I encountered the unexpected result
        System.out.println("Unexpected rank detected at batch starting index: " + i);
        break;
    }
}
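To rule out a simple out-of-range interval as the cause, the batch boundaries can be checked without ND4J at all. The sketch below uses a hypothetical class BatchBounds with a hypothetical method firstOverrunningStart; it assumes the same half-open intervals [i, i + batchSize) as the loop above and the numbers given in this question (3720 rows, batch size 24).

```java
// Hypothetical helper, not part of ND4J: returns the first batch start whose
// interval [start, start + batchSize) runs past `total`, or -1 if every batch fits.
class BatchBounds {
    static long firstOverrunningStart(long total, int batchSize) {
        for (long start = 0; start < total; start += batchSize) {
            if (start + batchSize > total) {
                return start;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        long total = 3720;  // y.size(0) expected from 4392 - 672
        int batchSize = 24;
        System.out.println(total / batchSize);  // prints 155 (number of batches)
        System.out.println(3072 / batchSize);   // prints 128 (index of the failing batch)
        System.out.println(firstOverrunningStart(total, batchSize)); // prints -1
    }
}
```

Under these assumptions every interval, including [3072, 3096), lies inside y, so the rank drop does not look like a plain bounds problem in the loop itself.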
I anticipated no rank issues under these controlled conditions, but the actual result was a batch with a rank of 0. That makes no sense in this context and suggests a problem with either the slicing or the data itself. The issue is critical because it leads to errors during model fitting in ND4J.