Custom Model with ML Kit Returning Incorrect Bounding Boxes, Mismatch with Preview


I am working on an Android Studio project that uses a custom model wrapped with Google ML Kit for real-time object detection. However, the bounding boxes returned by the model do not line up with the preview displayed on screen, even though the labels seem to be accurate. The input image is 1920 x 1920 and my preview is 2239 x 1080. Here's my function:

private void BindPreview(ProcessCameraProvider CameraProvider) {
    speaker.speakText("Please hold phone in front of you to detect obstacles");

    preview = new Preview.Builder()
            .setTargetResolution(new Size(1920, 1920))
            .build();
    cameraSelector = new CameraSelector.Builder().requireLensFacing(camFacing).build();
    preview.setSurfaceProvider(previewView.getSurfaceProvider());

    imageAnalysis = new ImageAnalysis.Builder()
            .setTargetResolution(new Size(1920, 1920))
            .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
            .build();

    imageAnalysis.setAnalyzer(ContextCompat.getMainExecutor(this),
            new ImageAnalysis.Analyzer() {
                @ExperimentalGetImage
                @Override
                public void analyze(@NonNull ImageProxy imageProxy) {
                    Image image = imageProxy.getImage();
                    if (image != null) {
                        InputImage inputImage = InputImage.fromMediaImage(image, imageProxy.getImageInfo().getRotationDegrees());
                        Task<List<DetectedObject>> task = objectDetector.process(inputImage);

                        task.addOnSuccessListener(
                                new OnSuccessListener<List<DetectedObject>>() {
                                    @Override
                                    public void onSuccess(List<DetectedObject> detectedObjects) {
                                        if (!detectedObjects.isEmpty()) {
                                            objectQueue.addAll(detectedObjects);
                                            Matrix mappingMatrix = ProjectHelper.getMappingMatrix(imageProxy, previewView);
                                            for (DetectedObject object : detectedObjects) {
                                                Rect boundingBox = ProjectHelper.mapBoundingBox(object.getBoundingBox(), mappingMatrix);
                                                rectangleOverlayView.updateRect(boundingBox);
                                                rectangleOverlayView.invalidate();
                                            }
                                        }
                                    }
                                }
                        )
                                .addOnFailureListener(
                                        new OnFailureListener() {
                                            @Override
                                            public void onFailure(@NonNull Exception e) {
                                                Log.e("Object Detection", e.getMessage());
                                            }
                                        }
                                )
                                .addOnCompleteListener(
                                        new OnCompleteListener<List<DetectedObject>>() {
                                            @Override
                                            public void onComplete(@NonNull Task<List<DetectedObject>> task) {
                                                imageProxy.close();
                                                image.close();
                                            }
                                        }
                                );

                    }
                }
            });

    CameraProvider.bindToLifecycle((LifecycleOwner) this, cameraSelector, imageAnalysis, preview);
}
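
For context, my understanding is that PreviewView's default FILL_CENTER scale type uniformly scales the analysis frame to cover the view and then center-crops it, so any box drawn on the overlay has to go through the same scale-and-translate step. Here is a minimal standalone sketch of that math with my dimensions (the class name, helper names, and example box are just illustrative, not code from my project):

import android.graphics.Matrix;
import android.graphics.Rect;
import android.graphics.RectF;

// Illustrative sketch only: image-to-view mapping assuming FILL_CENTER behaviour.
public final class FillCenterMappingSketch {

    // Uniformly scales the (already rotated) analysis frame so it covers the view,
    // then centers it, cropping whatever overflows the view bounds.
    public static Matrix imageToView(float imageWidth, float imageHeight,
                                     float viewWidth, float viewHeight) {
        float scale = Math.max(viewWidth / imageWidth, viewHeight / imageHeight);
        float dx = (viewWidth - imageWidth * scale) / 2f;
        float dy = (viewHeight - imageHeight * scale) / 2f;
        Matrix matrix = new Matrix();
        matrix.postScale(scale, scale);
        matrix.postTranslate(dx, dy);
        return matrix;
    }

    // Example with my numbers: a 1920 x 1920 analysis frame shown in a 2239 x 1080 view.
    public static Rect mapExampleBox() {
        Matrix matrix = imageToView(1920f, 1920f, 2239f, 1080f);
        RectF box = new RectF(400f, 400f, 900f, 900f); // hypothetical detector box in image coordinates
        matrix.mapRect(box);
        Rect mapped = new Rect();
        box.round(mapped);
        return mapped; // view-space rectangle I would expect the overlay to draw
    }
}

If getMappingMatrix ends up producing an equivalent matrix, I would expect the overlay to line up, so I suspect either that matrix or the coordinates I'm feeding into it are where things go wrong.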

I got the 'getMappingMatrix' function from another Stack Overflow question, but it didn't seem to help much. Here is that function:

    public static Matrix getMappingMatrix(ImageProxy imageProxy, PreviewView previewView) {
        Rect cropRect = imageProxy.getCropRect();
        int rotationDegrees = imageProxy.getImageInfo().getRotationDegrees();
        Matrix matrix = new Matrix();

        // A float array of the source vertices (crop rect) in clockwise order.
        float[] source = {
                cropRect.left,
                cropRect.top,
                cropRect.right,
                cropRect.top,
                cropRect.right,
                cropRect.bottom,
                cropRect.left,
                cropRect.bottom
        };

        // A float array of the destination vertices in clockwise order.
        float[] destination = {
                0f,
                0f,
                previewView.getWidth(),
                0f,
                previewView.getWidth(),
                previewView.getHeight(),
                0f,
                previewView.getHeight()
        };

        // The destination vertexes need to be shifted based on rotation degrees.
        // The rotation degree represents the clockwise rotation needed to correct
        // the image.

        // Each vertex is represented by 2 float numbers in the vertices array.
        int vertexSize = 2;
        // The destination needs to be shifted 1 vertex for every 90° rotation.
        int shiftOffset = rotationDegrees / 90 * vertexSize;
        float[] tempArray = destination.clone();
        for (int toIndex = 0; toIndex < source.length; toIndex++) {
            int fromIndex = (toIndex + shiftOffset) % source.length;
            destination[toIndex] = tempArray[fromIndex];
        }
        matrix.setPolyToPoly(source, 0, destination, 0, 4);
        return matrix;
    }
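
For completeness, mapBoundingBox is only meant to push the detector's Rect through that matrix; a minimal sketch of such a helper, assuming it does nothing more than call Matrix.mapRect (Rect, RectF and Matrix being the android.graphics classes), looks like this:

    // Sketch of a mapBoundingBox-style helper (assumed behaviour): applies the
    // image-to-view matrix to a detector bounding box and rounds back to integer pixels.
    public static Rect mapBoundingBox(Rect boundingBox, Matrix mappingMatrix) {
        RectF mapped = new RectF(boundingBox);
        mappingMatrix.mapRect(mapped); // transform the box corners into view space
        Rect result = new Rect();
        mapped.round(result);
        return result;
    }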

I'm using a model recommended by the ML Kit documentation: the EfficientNet-Lite Image Classifier.

Here is my onCreate function:


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_object_detection);
        LocalModel localModel =
                new LocalModel.Builder()
                        .setAssetFilePath("2.tflite")
                        .build();

        CustomObjectDetectorOptions customObjectDetectorOptions =
                new CustomObjectDetectorOptions.Builder(localModel)
                        .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
                        .enableClassification()
                        .setClassificationConfidenceThreshold(0.5f)
                        .setMaxPerObjectLabelCount(1)
                        .build();

        objectDetector = ObjectDetection.getClient(customObjectDetectorOptions);
        previewView = findViewById(R.id.cameraPreview);
        context = this;
        rectangleOverlayView = findViewById(R.id.rectangle_overlay);

        cameraProviderFuture = ProcessCameraProvider.getInstance(this);
        cameraProviderFuture.addListener(() -> {
            try {
                cameraProvider = cameraProviderFuture.get();
                if(ContextCompat.checkSelfPermission(ObjectDetectionActivity.this, android.Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED){
                    activityResultLauncher.launch(Manifest.permission.CAMERA);
                } else{
                    BindPreview(cameraProvider);
                }
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace(); // Handle exceptions as needed
                Log.e("CamerX Camera Provider", e.getMessage());
            }
        }, ContextCompat.getMainExecutor(this));
    }