Why does my U-Net segmentation model not work in my Android app?


I am developing an Android app with Android Studio that detects and segments chronic wounds in an image. To do this, I created two ML models: a detection model built with TensorFlow and a segmentation model based on a U-Net architecture. Both models take as input an RGB image of size 320×320. I then converted these models to the TFLite format to embed them in my app. I managed to implement the detection model correctly, and it detects the shapes I am interested in. I then crop the detected shapes and send them to my segmentation model after resizing them to 320×320 (the input size expected by my U-Net model).
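
For reference, the crop-and-resize step looks roughly like this (a simplified sketch, not my exact code; box is the bounding rectangle returned by the detection model for one wound):

    // Crop one detected wound out of the original image and scale it to the
    // 320x320 input size expected by the U-Net model.
    private Bitmap cropAndResize(Bitmap original, Rect box) {
        Bitmap cropped = Bitmap.createBitmap(original, box.left, box.top, box.width(), box.height());
        return Bitmap.createScaledBitmap(cropped, 320, 320, true);
    }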

This is where my problem appears: I get an entirely black image, as if my U-Net model had detected nothing. I do get a list with one value per pixel, but these values are very low, because every pixel is classified as not belonging to the detected shape, which is wrong. The metrics of my model are very good and the tests I performed were satisfactory. What intrigues me the most is that it detects absolutely nothing, so I think there is a problem in the way I process the images before or after segmentation in Java.
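
To make the symptom concrete, here is the kind of sanity check I run on the raw output (a minimal sketch; prediction is the float array returned by getFloatArray() in the code below):

    // Log the range of the raw model output: if the maximum stays close to 0,
    // the model really classifies every pixel as background.
    float min = Float.MAX_VALUE, max = -Float.MAX_VALUE;
    for (float p : prediction) {
        min = Math.min(min, p);
        max = Math.max(max, p);
    }
    Log.d("Segmentation", "prediction range: [" + min + ", " + max + "]");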

Here is the relevant part of my Android Studio code:


    public void createMask(List<Bitmap> wounds){
        try {
            Weights2db model = Weights2db.newInstance(getContext());

            Bitmap finalMask = Bitmap.createBitmap(imageSize, imageSize, Bitmap.Config.ARGB_8888);
            // Creates inputs for reference.
            for(Bitmap wound : wounds) {

                Bitmap resizedWound = Bitmap.createScaledBitmap(wound, imageSize, imageSize, true);

                TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 320, 320, 3}, DataType.FLOAT32);
                ByteBuffer byteBuffer2 = convertBitmapToByteBufferMask(resizedWound);

                inputFeature0.loadBuffer(byteBuffer2);

                // Runs model inference and gets result.
                Weights2db.Outputs outputs = model.process(inputFeature0);
                TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();

                float[] prediction = outputFeature0.getFloatArray();

                float threshold = 0.5f;

                Bitmap maskBitmap = binarizeMask(prediction, threshold, imageSize, imageSize);

                overlayMask(finalMask, maskBitmap);
            }
            // Releases model resources if no longer used.
            model.close();

            imageView.setImageBitmap(finalMask);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }


    private ByteBuffer convertBitmapToByteBufferMask(Bitmap bitmap){
        // 4 bytes per float * imageSize * imageSize pixels * 3 channels (RGB)
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4*imageSize*imageSize*3);
        byteBuffer.order(ByteOrder.nativeOrder());

        int[] intValues = new int[imageSize*imageSize];
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        int pixel = 0;
        for (int i = 0; i<imageSize; i++){
            for(int j = 0; j<imageSize; j++) {
                int val = intValues[pixel++];
                // Extract R, G, B from the ARGB int and normalize each channel to [0, 1]
                byteBuffer.putFloat(((val>>16)&0xFF)/255.f);
                byteBuffer.putFloat(((val>>8)&0xFF)/255.f);
                byteBuffer.putFloat((val&0xFF)/255.f);
            }
        }
        return byteBuffer;
    }

    private Bitmap binarizeMask(float[] predictions, float threshold, int width, int height) {
        Bitmap maskBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        int pixelIndex = 0;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                float prediction = predictions[pixelIndex++];
                int pixelColor = (prediction > threshold) ? Color.WHITE : Color.BLACK;
                maskBitmap.setPixel(x, y, pixelColor);
            }
        }
        return maskBitmap;
    }

And here is the sample code generated for my TFLite segmentation model:

try {
    Weights2db model = Weights2db.newInstance(context);

    // Creates inputs for reference.
    TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 320, 320, 3}, DataType.FLOAT32);
    inputFeature0.loadBuffer(byteBuffer);

    // Runs model inference and gets result.
    Weights2db.Outputs outputs = model.process(inputFeature0);
    TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();

    // Releases model resources if no longer used.
    model.close();
} catch (IOException e) {
    // TODO Handle the exception
}

I am also sharing the post-processing that I do in Python after the segmentation prediction. This part works well and displays the segmentation correctly:

import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow

id = 0
for x, y, original_size in zip(X_test, y_test, original_sizes):
    # Reshape input image for model prediction
    cv2_imshow(x)
    x_input = np.expand_dims(x, axis=0)
    # Predict using the model
    y_pred = model.predict(x_input)
    y_pred = tf.image.resize(y_pred, [original_size[0], original_size[1]])
    y_pred = y_pred.numpy()
    y_pred = 1 * (y_pred > 0.5)
    temp = np.squeeze(np.uint8(y_pred[0] * 255))

    # Create an empty mask with the correct shape
    y_pred_resized = np.zeros((original_size[0], original_size[1], 1), dtype=np.uint8)
    y_pred_resized[:, :, 0] = temp

    # Resize predicted mask back to the original dimensions
    y_pred_resized = cv2.resize(y_pred_resized, (original_size[1], original_size[0]))
    y_resized = cv2.resize(y, (original_size[1], original_size[0]))

    # Mask Binarisation
    y_resized *= 255
    _, y_resized_mask = cv2.threshold(y_resized, 127, 255, cv2.THRESH_BINARY)
    _, y_pred_binarized = cv2.threshold(y_pred_resized, 127, 255, cv2.THRESH_BINARY)

    print("Real mask")
    plt.imshow(y_resized_mask, cmap='gray')
    plt.show()
    print("Predicted mask")
    plt.imshow(y_pred_resized, cmap='gray')  # Display the mask in grayscale
    plt.show()
    img_save = f"segmented_{id}.png"
    img_path = os.path.join(save_path, img_save)

    cv2.imwrite(img_path, y_pred_resized)

    id += 1

Even when I set the threshold to a very low value, the displayed image was still entirely black.

Once the prediction was made, I also tried multiplying each pixel value by a large factor, and the result showed that the segmentation does not work properly: some pixels belonging to the background were colored white, while others belonging to the detected shape stayed black. Since my U-Net model performs very well, these results surprise me.
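
Concretely, this visualization experiment looked roughly like this (a simplified sketch, not my exact code):

    // Render the raw probabilities as a grayscale heatmap instead of thresholding
    // them, to see what the model actually outputs (values clamped to [0, 255]).
    private Bitmap probabilityHeatmap(float[] predictions, int width, int height) {
        int[] pixels = new int[width * height];
        for (int i = 0; i < pixels.length; i++) {
            int v = Math.min(255, Math.max(0, (int) (predictions[i] * 255f)));
            pixels[i] = Color.rgb(v, v, v);
        }
        Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        bmp.setPixels(pixels, 0, width, 0, 0, width, height);
        return bmp;
    }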

I also debugged my application to follow each step in detail, and the image I provide to my model is the one I expect, so that is not the problem.

Edit: I ran my TFLite model on Google Colab to see whether the problem came from the model itself, using the same input image that I use to test my Android app. In that case it gives me a good segmentation of the shape, so the model performs well. I guess my issue comes from my Android Studio code and not from my model, but I cannot find it. Here is my Google Colab code:

import cv2
import numpy as np
import tensorflow as tf
from google.colab.patches import cv2_imshow

interpreter = tf.lite.Interpreter('/mydrive2/chronic_wounds_project/Segmentation:U-Net_2DB/results/weights_2DB.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

image = cv2.imread(image_path)

desired_size = (320, 320)
image = cv2.resize(image, desired_size)
cv2_imshow(image)

x_input = np.expand_dims(image, axis=0)

# Note: the pixel values stay in the raw 0-255 range here, only cast to float32
x_input_float32 = x_input.astype(np.float32)

interpreter.set_tensor(input_details[0]['index'], x_input_float32)
interpreter.invoke()
y_pred = interpreter.get_tensor(output_details[0]['index'])
print(y_pred)
print(y_pred.shape)

y_pred_resized = cv2.resize(y_pred[0], (image.shape[1], image.shape[0]))

y_pred_binarized = np.uint8(y_pred_resized > 0.5) * 255

cv2_imshow(y_pred_binarized)

My model takes as input an image in 1×320×320×3 format and produces an output in 1×320×320×1 format.
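
To double-check these shapes on the Android side, I can query the raw interpreter directly (a sketch using the plain TensorFlow Lite Interpreter API instead of the generated Weights2db wrapper; the file name weights_2DB.tflite and its presence in the app's assets are assumptions):

    // Inspect the input/output tensor shapes of the bundled .tflite file at runtime.
    try (Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(context, "weights_2DB.tflite"))) {
        int[] inShape = interpreter.getInputTensor(0).shape();   // expected {1, 320, 320, 3}
        int[] outShape = interpreter.getOutputTensor(0).shape(); // expected {1, 320, 320, 1}
        Log.d("Segmentation", Arrays.toString(inShape) + " -> " + Arrays.toString(outShape));
    } catch (IOException e) {
        e.printStackTrace();
    }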
