Using a TFLite object detection model in Android: how do I get the class ID of each detection?

I am trying to scan an image that may contain multiple shapes of each class defined in the TFLite model. The confidence scores and bounding boxes are being returned, but I cannot figure out the class ID of each detection. (For reference, the TFLite model was trained from ssd_mobilenet_v1_fpn.)
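A TFLite SSD model that ends in the detection post-processing op normally has four outputs (boxes, class indices, scores, number of detections), and their order can differ depending on how the model was exported, so `outputFeature2` is not guaranteed to hold class IDs. The following is a minimal heuristic sketch, not part of any TFLite API: given each output's shape and flattened values (on Android, `TensorBuffer.getShape()` and `getFloatArray()` supply these), it guesses which output holds class indices. It assumes boxes have shape `[1, N, 4]`, the count output has a single element, and class indices are whole numbers while scores are fractional; `OutputRoleGuesser` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

public class OutputRoleGuesser {

    /** Returns true if every value is a whole number, as class indices are. */
    static boolean allWholeNumbers(float[] values) {
        for (float v : values) {
            if (v != Math.rint(v)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Heuristic: returns the index of the output that most likely holds class IDs,
     * or -1 if none matches. shapes.get(i) and values.get(i) describe output i.
     */
    static int guessClassIdOutput(List<int[]> shapes, List<float[]> values) {
        for (int i = 0; i < shapes.size(); i++) {
            int[] shape = shapes.get(i);
            float[] v = values.get(i);
            boolean isBoxes = shape.length == 3 && shape[2] == 4; // [1, N, 4]
            boolean isCount = v.length == 1;                      // [1]
            if (!isBoxes && !isCount && allWholeNumbers(v)) {
                return i; // whole-number [1, N] output: probably class indices
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // Hypothetical outputs in one possible order: scores, boxes, count, classes.
        List<int[]> shapes = Arrays.asList(
                new int[]{1, 3}, new int[]{1, 3, 4}, new int[]{1}, new int[]{1, 3});
        List<float[]> values = Arrays.asList(
                new float[]{0.91f, 0.78f, 0.40f},
                new float[12],
                new float[]{3f},
                new float[]{0f, 1f, 0f});
        System.out.println("Class IDs are probably in output " + guessClassIdOutput(shapes, values));
    }
}
```

Run it on an image that produces at least one detection: if every score is padded with zeros, the scores output also looks whole-numbered and the heuristic can pick it instead.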

For example, if the model is trained on {"apple", "oranges"} and the image we scan contains several apples and oranges, we need to detect every occurrence along with its bounding box.
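Once the class-ID output is identified, each ID is an index into the label list the model was trained with. A minimal sketch, assuming the labels are {"apple", "oranges"} in training order and the IDs are 0-based into that list (some label files start with a background entry such as `???`, which shifts every ID by one), with an arbitrary 0.5 confidence threshold; `LabelMapper` and `Detection` are hypothetical names, not TFLite classes.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class LabelMapper {

    /** One labelled detection; a hypothetical holder class for this sketch. */
    static final class Detection {
        final String label;
        final float confidence;

        Detection(String label, float confidence) {
            this.label = label;
            this.confidence = confidence;
        }
    }

    /** Keeps detections at or above the threshold and attaches a label to each class ID. */
    static List<Detection> label(int[] classIds, float[] confidences, String[] labels, float threshold) {
        List<Detection> result = new ArrayList<>();
        for (int i = 0; i < classIds.length; i++) {
            int id = classIds[i];
            if (confidences[i] < threshold || id < 0 || id >= labels.length) {
                continue; // low score, or an ID outside the label list
            }
            result.add(new Detection(labels[id], confidences[i]));
        }
        return result;
    }

    /** Counts how many detections of each label were kept. */
    static Map<String, Integer> countByLabel(List<Detection> detections) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (Detection d : detections) {
            counts.merge(d.label, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        String[] labels = {"apple", "oranges"};
        int[] classIds = {0, 1, 0, 1};
        float[] confidences = {0.92f, 0.81f, 0.67f, 0.12f};
        List<Detection> found = label(classIds, confidences, labels, 0.5f);
        System.out.println(countByLabel(found)); // {apple=2, oranges=1}
    }
}
```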

Here is the code:


import android.graphics.Bitmap;
import android.util.Log;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
// Modeld is the class generated by Android Studio's ML Model Binding (import it from your app's .ml package).

public void classifyImaged(Bitmap image) {
    final int imageSize = 640; // must match the model's input shape [1, 640, 640, 3]
    Modeld model = null;
    try {
        model = Modeld.newInstance(getApplicationContext());

        // Creates inputs for reference.
        TensorBuffer inputFeature = TensorBuffer.createFixedSize(new int[]{1, imageSize, imageSize, 3}, DataType.FLOAT32);

        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * imageSize * imageSize * 3);
        byteBuffer.order(ByteOrder.nativeOrder());

        // Resize first, so the pixel array below exactly matches the input tensor.
        Bitmap scaled = Bitmap.createScaledBitmap(image, imageSize, imageSize, true);

        // get 1D array of imageSize * imageSize pixels (packed ARGB ints)
        int[] intValues = new int[imageSize * imageSize];
        scaled.getPixels(intValues, 0, imageSize, 0, 0, imageSize, imageSize);

        int pixel = 0;
        for (int i = 0; i < imageSize; i++) {
            for (int j = 0; j < imageSize; j++) {
                int val = intValues[pixel++]; // ARGB
                byteBuffer.putFloat(((val >> 16) & 0xFF) * (1.f / 255.f)); // R
                byteBuffer.putFloat(((val >> 8) & 0xFF) * (1.f / 255.f));  // G
                byteBuffer.putFloat((val & 0xFF) * (1.f / 255.f));         // B
            }
        }

        inputFeature.loadBuffer(byteBuffer);
        // process() runs synchronously; call classifyImaged() from a background thread
        // instead of submitting to a new thread pool and blocking on get().
        Modeld.Outputs outputs = model.process(inputFeature);

        // Assuming outputFeature0 contains confidence scores, outputFeature1 contains bounding boxes,
        // and outputFeature2 contains class IDs. The real order depends on how the model was exported,
        // so check it against the model's metadata before trusting these labels.
        TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer(); // Confidence
        TensorBuffer outputFeature1 = outputs.getOutputFeature1AsTensorBuffer(); // Bounding Boxes
        TensorBuffer outputFeature2 = outputs.getOutputFeature2AsTensorBuffer(); // Class IDs

        float[] confidences = outputFeature0.getFloatArray();
        float[] boundingBoxes = outputFeature1.getFloatArray();
        // Class indices are usually stored as FLOAT32; getIntArray() casts them to int.
        int[] classIds = outputFeature2.getIntArray();

        int numDetections = Math.min(Math.min(confidences.length, boundingBoxes.length / 4), classIds.length);

        for (int i = 0; i < numDetections; i++) {
            float confidence = confidences[i];
            // SSD detection post-processing emits normalized [ymin, xmin, ymax, xmax], not [x, y, w, h].
            float top = boundingBoxes[i * 4];
            float left = boundingBoxes[i * 4 + 1];
            float bottom = boundingBoxes[i * 4 + 2];
            float right = boundingBoxes[i * 4 + 3];
            int classId = classIds[i];
            System.out.println("ClassID: " + classId + ", Bounding Box: [" + left + ", " + top + ", " + right + ", " + bottom + "], Confidence: " + confidence);
        }
    } catch (IOException e) {
        Log.e("classifyImaged", "Could not load or run the model", e);
    } finally {
        // Always release the model, even if inference failed.
        if (model != null) {
            model.close();
        }
    }
}
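For SSD models that end in TFLite's detection post-processing op, boxes are usually normalized `[ymin, xmin, ymax, xmax]` rather than `[x, y, width, height]`, and one output holds the number of valid detections. A minimal sketch, assuming that layout (verify it against your model), converting the flattened output arrays into pixel-space boxes; `SsdDecoder`, `Box` and `decode` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class SsdDecoder {

    /** A detection in pixel coordinates of the original image. */
    static final class Box {
        final int classId;
        final float score, left, top, right, bottom;

        Box(int classId, float score, float left, float top, float right, float bottom) {
            this.classId = classId;
            this.score = score;
            this.left = left;
            this.top = top;
            this.right = right;
            this.bottom = bottom;
        }

        @Override
        public String toString() {
            return String.format(Locale.ROOT, "class=%d score=%.2f [%.0f, %.0f, %.0f, %.0f]",
                    classId, score, left, top, right, bottom);
        }
    }

    /**
     * boxes: 4 floats per detection as normalized [ymin, xmin, ymax, xmax];
     * classes and scores: one float per detection; count: number of valid detections.
     */
    static List<Box> decode(float[] boxes, float[] classes, float[] scores, int count,
                            int imageWidth, int imageHeight, float minScore) {
        List<Box> result = new ArrayList<>();
        int n = Math.min(count, Math.min(scores.length, Math.min(classes.length, boxes.length / 4)));
        for (int i = 0; i < n; i++) {
            if (scores[i] < minScore) {
                continue; // skip low-confidence detections
            }
            float ymin = boxes[4 * i], xmin = boxes[4 * i + 1];
            float ymax = boxes[4 * i + 2], xmax = boxes[4 * i + 3];
            result.add(new Box(Math.round(classes[i]), scores[i],
                    xmin * imageWidth, ymin * imageHeight, xmax * imageWidth, ymax * imageHeight));
        }
        return result;
    }

    public static void main(String[] args) {
        float[] boxes = {0.1f, 0.2f, 0.5f, 0.6f, 0.3f, 0.0f, 0.9f, 0.4f};
        float[] classes = {1f, 0f};
        float[] scores = {0.88f, 0.73f};
        for (Box b : decode(boxes, classes, scores, 2, 1000, 800, 0.5f)) {
            System.out.println(b);
        }
    }
}
```

Because the coordinates are normalized, scaling by the original bitmap's width and height (rather than 640) gives boxes that line up with the image the user sees.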

I am unable to figure out how to get the class ID of each detection. Thanks in advance for the help.
