I am trying to scan an image that may contain multiple objects of each class defined in the TFLite model. The confidences and bounding boxes are being resolved, but I am unable to figure out the class ID of each detection. (For reference, the TFLite model was trained with ssd_mobilenet_v1_fpn.)
For example, if the model is trained for {"apple","oranges"}, the image we are trying to scan contains multiple apples and oranges, and we need to detect every occurrence along with its bounding box.
The following is the code:
public void classifyImaged(Bitmap image){
    try {
        Modeld model = Modeld.newInstance(getApplicationContext());
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        // Creates inputs for reference.
        TensorBuffer inputFeature = TensorBuffer.createFixedSize(new int[]{1, 640, 640, 3}, DataType.FLOAT32);
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * imageSize * imageSize * 3);
        byteBuffer.order(ByteOrder.nativeOrder());
        // get 1D array of imageSize * imageSize pixels in image
        int[] intValues = new int[imageSize * imageSize];
        image.getPixels(intValues, 0, image.getWidth(), 0, 0, image.getWidth(), image.getHeight());
        int pixel = 0;
        for (int i = 0; i < imageSize; i++) {
            for (int j = 0; j < imageSize; j++) {
                int val = intValues[pixel++]; // packed ARGB pixel
                // write R, G, B as floats scaled to [0, 1]
                byteBuffer.putFloat(((val >> 16) & 0xFF) * (1.f / 255.f));
                byteBuffer.putFloat(((val >> 8) & 0xFF) * (1.f / 255.f));
                byteBuffer.putFloat((val & 0xFF) * (1.f / 255.f));
            }
        }
        inputFeature.loadBuffer(byteBuffer);
        Modeld.Outputs outputs = executor.submit(() -> model.process(inputFeature)).get();
        // Assuming outputFeature0 contains confidence scores, outputFeature1 contains bounding boxes, and outputFeature2 contains class IDs
        TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer(); // Confidence
        TensorBuffer outputFeature1 = outputs.getOutputFeature1AsTensorBuffer(); // Bounding Boxes
        TensorBuffer outputFeature2 = outputs.getOutputFeature2AsTensorBuffer(); // Class IDs
        float[] confidences = outputFeature0.getFloatArray();
        float[] boundingBoxes = outputFeature1.getFloatArray();
        int[] classIds = outputFeature2.getIntArray();
        int numDetections = Math.min(Math.min(confidences.length, boundingBoxes.length / 4), classIds.length);
        for (int i = 0; i < numDetections; i++) {
            int confidenceIndex = i;
            int bboxIndex = i * 4;
            int classIdIndex = i;
            float confidence = confidences[confidenceIndex];
            float x = boundingBoxes[bboxIndex];
            float y = boundingBoxes[bboxIndex + 1];
            float width = boundingBoxes[bboxIndex + 2];
            float height = boundingBoxes[bboxIndex + 3];
            int classId = classIds[classIdIndex];
            System.out.println("ClassID: " + classId + ", Bounding Box: [" + x + ", " + y + ", " + width + ", " + height + "], Confidence: " + confidence);
        }
        executor.shutdown();
        model.close();
    } catch (IOException e) {
        // TODO Handle the exception
    } catch (ExecutionException e) {
        throw new RuntimeException(e);
    } catch (InterruptedException e) {
        throw new RuntimeException(e);
    }
}
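The comment in the code only assumes which output holds which values. One way to check that assumption is to look at the shape and contents of each output. The sketch below is plain Java with made-up sample data. `OutputRoleGuesser`, `Role`, and `guessRole` are hypothetical names, not part of TFLite or the generated `Modeld` wrapper. On Android, the shape and values would presumably come from `TensorBuffer.getShape()` and `TensorBuffer.getFloatArray()`. The heuristic assumes the usual four detection outputs (boxes, classes, scores, number of detections). It can be wrong, for example if the model returns only one detection or every score happens to be a whole number.

```java
import java.util.Arrays;

/** Hypothetical helper for guessing what a raw output tensor holds; not a TFLite API. */
final class OutputRoleGuesser {

    enum Role { BOXES, COUNT, SCORES, CLASSES, UNKNOWN }

    /** Guesses the role of one output from its shape and flattened values (heuristic only). */
    static Role guessRole(int[] shape, float[] values) {
        if (values.length == 0) {
            return Role.UNKNOWN;
        }
        // A single value is most likely the number of detections.
        if (values.length == 1) {
            return Role.COUNT;
        }
        // Boxes carry four coordinates per detection: shape [1, N, 4].
        if (shape.length == 3 && shape[2] == 4) {
            return Role.BOXES;
        }
        if (shape.length == 2) {
            // Assumption: scores are fractional confidences, while class IDs
            // are whole numbers stored as floats (0.0, 1.0, ...).
            for (float v : values) {
                if (v != Math.rint(v)) {
                    return Role.SCORES;
                }
            }
            return Role.CLASSES;
        }
        return Role.UNKNOWN;
    }

    public static void main(String[] args) {
        // Made-up sample outputs for three detections.
        int[][] shapes = { {1, 3}, {1, 3, 4}, {1}, {1, 3} };
        float[][] outputs = {
            {0.9f, 0.8f, 0.2f},
            {0.1f, 0.1f, 0.4f, 0.4f, 0.5f, 0.5f, 0.9f, 0.9f, 0f, 0f, 1f, 1f},
            {3f},
            {0f, 1f, 0f}
        };
        for (int i = 0; i < outputs.length; i++) {
            System.out.println("output " + i + " shape " + Arrays.toString(shapes[i])
                    + " -> " + guessRole(shapes[i], outputs[i]));
        }
    }
}
```

If an output reported as `COUNT` sits where the class IDs were expected, the class IDs are probably in one of the other outputs.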
Thanks in advance for the help.
I am unable to figure out how to get the class ID of each detection.
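To make the goal concrete, here is a minimal sketch of the decoding step, written as plain Java on float arrays so it runs without Android. It rests on several assumptions that may not hold for this model: class IDs arrive as floats, they are zero-based indices into the label list, and each box is four normalized values in the order [ymin, xmin, ymax, xmax]. `DetectionDecoder`, `Detection`, `decode`, and the 0.5 threshold are hypothetical and exist only for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/** Hypothetical helper; not part of the generated Modeld wrapper or the TFLite Support Library. */
final class DetectionDecoder {

    /** One decoded detection with an assumed [ymin, xmin, ymax, xmax] normalized box. */
    static final class Detection {
        final int classId;
        final String label;
        final float score;
        final float yMin, xMin, yMax, xMax;

        Detection(int classId, String label, float score,
                  float yMin, float xMin, float yMax, float xMax) {
            this.classId = classId;
            this.label = label;
            this.score = score;
            this.yMin = yMin;
            this.xMin = xMin;
            this.yMax = yMax;
            this.xMax = xMax;
        }

        @Override
        public String toString() {
            return String.format(Locale.ROOT, "%s (id %d) score %.2f box [%.2f, %.2f, %.2f, %.2f]",
                    label, classId, score, yMin, xMin, yMax, xMax);
        }
    }

    /** Pairs each score with its class ID and box, keeping detections at or above minScore. */
    static List<Detection> decode(float[] boxes, float[] classes, float[] scores,
                                  int count, String[] labels, float minScore) {
        // Never read past the shortest array, whatever the count says.
        int n = Math.min(count,
                Math.min(scores.length, Math.min(classes.length, boxes.length / 4)));
        List<Detection> result = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (scores[i] < minScore) {
                continue;
            }
            // Assumption: class IDs are floats such as 0.0 or 1.0, zero-based.
            int classId = Math.round(classes[i]);
            String label = (classId >= 0 && classId < labels.length)
                    ? labels[classId]
                    : "unknown(" + classId + ")";
            int b = i * 4;
            result.add(new Detection(classId, label, scores[i],
                    boxes[b], boxes[b + 1], boxes[b + 2], boxes[b + 3]));
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up sample data: two confident detections and one weak one.
        String[] labels = {"apple", "oranges"};
        float[] boxes = {0.1f, 0.1f, 0.4f, 0.4f, 0.5f, 0.5f, 0.9f, 0.9f, 0f, 0f, 1f, 1f};
        float[] classes = {0f, 1f, 0f};
        float[] scores = {0.9f, 0.8f, 0.2f};
        for (Detection d : decode(boxes, classes, scores, 3, labels, 0.5f)) {
            System.out.println(d);
        }
    }
}
```

With the real model, the three arrays would come from whichever outputs turn out to hold boxes, class IDs, and scores, each read as a float array.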