How do I add Select TensorFlow ops to a Python interpreter on Android using Chaquopy?

I'm using Chaquopy within the Android code of a Flutter project to run a Python script that uses some TensorFlow Lite models.

Here's the Python script:

from io import BytesIO
import base64

import tensorflow as tf
from skimage import io
from imutils.object_detection import non_max_suppression
import numpy as np
import math
import time
import cv2
import string
from os.path import dirname, join


def preprocess_east(image: np.ndarray):
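    # Resize to 416x640 (EAST needs input dimensions that are multiples of 32)
    # and subtract the ImageNet channel means, reversed to match OpenCV's BGR order.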
    input_image = image
    orig = input_image.copy()
    (H, W) = input_image.shape[:2]
    (newW, newH) = (416, 640)
    rW = W / float(newW)
    rH = H / float(newH)
    image = cv2.resize(input_image, (newW, newH))
    (H, W) = image.shape[:2]
    image = image.astype("float32")
    mean = np.array([123.68, 116.779, 103.939][::-1], dtype="float32")
    image -= mean
    image = np.expand_dims(image, 0)
    return input_image, image, rW, rH


def run_east_tflite(input_data):
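    # Run the EAST text detector; it produces two outputs per grid cell:
    # confidence scores and box geometry.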
    model_path = join(dirname(__file__), "east_float_640.tflite")
    interpreter = tf.lite.Interpreter(model_path=model_path)
    input_details = interpreter.get_input_details()
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]["index"], input_data)
    interpreter.invoke()
    scores = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    geometry = interpreter.tensor(interpreter.get_output_details()[1]["index"])()
    return scores, geometry


def postprocess_east(scores, geometry, rW, rH, orig):
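    # The TFLite model emits NHWC tensors; transpose to NCHW so the decoding
    # below matches the usual OpenCV EAST example. Each output cell covers a
    # 4-pixel stride of the resized input.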
    scores = np.transpose(scores, (0, 3, 1, 2))
    geometry = np.transpose(geometry, (0, 3, 1, 2))
    (numRows, numCols) = scores.shape[2:4]
    rects = []
    confidences = []
    for y in range(0, numRows):
        scoresData = scores[0, 0, y]
        xData0 = geometry[0, 0, y]
        xData1 = geometry[0, 1, y]
        xData2 = geometry[0, 2, y]
        xData3 = geometry[0, 3, y]
        anglesData = geometry[0, 4, y]
        for x in range(0, numCols):
            if scoresData[x] < 0.5:
                continue
            (offsetX, offsetY) = (x * 4.0, y * 4.0)
            angle = anglesData[x]
            cos = np.cos(angle)
            sin = np.sin(angle)
            h = xData0[x] + xData2[x]
            w = xData1[x] + xData3[x]
            endX = int(offsetX + (cos * xData1[x]) + (sin * xData2[x]))
            endY = int(offsetY - (sin * xData1[x]) + (cos * xData2[x]))
            startX = int(endX - w)
            startY = int(endY - h)
            rects.append((startX, startY, endX, endY))
            confidences.append(scoresData[x])
    boxes = non_max_suppression(np.array(rects), probs=confidences)
    crops = []
    for startX, startY, endX, endY in boxes:
        startX = int(startX * rW)
        startY = int(startY * rH)
        endX = int(endX * rW)
        endY = int(endY * rH)
        cv2.rectangle(orig, (startX, startY), (endX, endY), (0, 0, 255), 3)
        crops.append([[startX, startY], [endX, endY]])
    return orig, crops


def preprocess_ocr(image):
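    # The recognizer expects a 31x200 grayscale crop scaled to [0, 1],
    # shaped [1, 31, 200, 1].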
    input_data = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    input_data = cv2.resize(input_data, (200, 31))
    input_data = input_data[np.newaxis]
    input_data = np.expand_dims(input_data, 3)
    input_data = input_data.astype("float32") / 255
    return input_data


def run_tflite_ocr(input_data):
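    # This is the interpreter that fails to prepare: the CTC greedy decoder
    # baked into this model is a Select (Flex) TF op, not a TFLite builtin.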
    model_path = join(dirname(__file__), "keras_ocr_float16_ctc.tflite")
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    interpreter.set_tensor(input_details[0]["index"], input_data)

    interpreter.invoke()

    output = interpreter.get_tensor(output_details[0]["index"])
    return output


alphabets = string.digits + string.ascii_lowercase
blank_index = len(alphabets)
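# The CTC blank is class 36, one past 'z' in the 36-character alphabet.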


def postprocess_ocr(output, greedy=True):
    # Running decoder on TFLite Output
    final_output = "".join(
        alphabets[index] for index in output[0] if index not in [blank_index, -1]
    )
    return final_output


def run_ocr(img_bytes: bytes, detector="east", greedy=True):
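    # Full pipeline: detect text boxes with EAST, crop each box, run the
    # recognizer on each crop, and draw the results onto the output image.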
    nd_array = read_image(img_bytes)
    start_time = time.time()
    input_image, preprocessed_image, rW, rH = preprocess_east(nd_array)
    scores, geometry = run_east_tflite(preprocessed_image)
    output, crops = postprocess_east(scores, geometry, rW, rH, input_image)
    font_scale = 1
    thickness = 2
    # i=0
    (h, w) = input_image.shape[:2]

    for box in crops:
        # i += 1
        yMin = box[0][1]
        yMax = box[1][1]
        xMin = box[0][0]
        xMax = box[1][0]
        xMin = max(0, xMin)
        yMin = max(0, yMin)
        xMax = min(w, xMax)
        yMax = min(h, yMax)

        cropped_image = input_image[yMin:yMax, xMin:xMax, :]
        # Uncomment these if you want to save the cropped images to an output folder
        # cv2.imwrite(f'output/{i}.jpg', cropped_image)
        # print("i: ", i)
        # print("Box: ", box)
        # plt_imshow("cropped_image", input_image)
        processed_image = preprocess_ocr(cropped_image)
        ocr_output = run_tflite_ocr(processed_image)
        final_output = postprocess_ocr(ocr_output, greedy)
        # print("Text output: ", final_output)
        # final_output = ''
        cv2.putText(
            output,
            final_output,
            (box[0][0], box[0][1] - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (0, 0, 255),
            thickness,
        )
    print(
        f"Time taken to run OCR Model with {detector} detector and KERAS OCR is",
        time.time() - start_time,
    )
    return output.tobytes()


def image_to_byte_array(image_path: str) -> bytes:
    with open(image_path, "rb") as image:
        return image.read()


def read_image(content: bytes) -> np.ndarray:
    """
    Image bytes to OpenCV image

    :param content: Image bytes
    :returns: OpenCV image
    :raises TypeError: If content is not bytes
    :raises ValueError: If content does not represent an image
    """
    if not isinstance(content, bytes):
        raise TypeError(f"Expected 'content' to be bytes, received: {type(content)}")
    image = cv2.imdecode(np.frombuffer(content, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Expected 'content' to be image bytes")
    return image


# image_path = r"/Users/josegeorges/Desktop/puro-labels/train/yes/label_1.jpg"
# img_bytes = image_to_byte_array(image_path)
# final_image = run_ocr(img_bytes, detector="east", greedy=True)


def call_ocr_from_android(img_bytes: bytearray):
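    # Entry point for the Android side: Chaquopy passes the image bytes in as
    # a bytearray, converted to immutable bytes before running the pipeline.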
    return run_ocr(img_bytes=bytes(img_bytes), detector="east", greedy=True)


# dst_folder = "./"
# out_file_name = "out_image.png"
# # Save the image in JPG format
# cv2.imwrite(os.path.join(dst_folder, out_file_name), final_image)

Here are the packages installed through Gradle, inside Chaquopy's pip block:

install "numpy"
install "opencv-python"
install "imutils"
install "scikit-image"
install "tensorflow"

I'm currently running into the following exception when trying to load the keras_ocr_float16_ctc.tflite model into an interpreter:

Regular TensorFlow ops are not supported by this interpreter. Make sure you apply/link the Flex delegate before inference.Node number 192 (FlexCTCGreedyDecoder) failed to prepare.

From what I've read in the TensorFlow docs, the Select TF ops should be available since I'm installing the TensorFlow pip package, but that doesn't seem to be the case. I also thought I needed to follow the Android instructions and add the org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly-SNAPSHOT dependency, but that doesn't seem to work either.

What can I do to run this Select TF ops model on Android using Chaquopy?

1 Answer

Accepted answer, by mhsmith:

Unfortunately, the TensorFlow documentation says:

TensorFlow Lite with select TensorFlow ops are available in the TensorFlow pip package version since 2.3 for Linux and 2.4 for other environments.

But Chaquopy's TensorFlow build is currently at version 2.1.
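You can confirm which version is actually installed by logging it from the script:

import tensorflow as tf
print(tf.__version__)  # prints 2.1.x under the current Chaquopy build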

As for the build.gradle dependencies block, that will only affect the TensorFlow Java API, not the Python API.

So the best options I can think of are:

  • Switch over to the Java API, where the select-tf-ops Gradle dependency does apply; or
  • Alter your model so it doesn't depend on the extra operators, e.g. by doing the CTC decoding outside the model (see the first sketch below); or
  • Try using the tflite-runtime pip package instead. Chaquopy currently provides this at version 2.5, so it may have some additional operators (see the second sketch below).
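For the second option, the failing node is the CTC greedy decoder, so one approach is to re-convert the model with only built-in TFLite ops (setting converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] at conversion time) and do the decoding in plain NumPy. A minimal sketch, assuming the re-converted model outputs raw per-timestep class scores of shape [1, time_steps, num_classes] with the CTC blank as the last class, and reusing the alphabets and blank_index already defined in your script:

import numpy as np

def ctc_greedy_decode(logits):
    # Pick the most likely class at each time step.
    best = np.argmax(logits[0], axis=-1)
    # CTC greedy decoding: collapse consecutive repeats, then drop blanks.
    decoded = []
    previous = -1
    for index in best:
        if index != previous and index != blank_index:
            decoded.append(int(index))
        previous = index
    return "".join(alphabets[i] for i in decoded)

For the third option, the swap is mechanical because tflite-runtime exposes the same interpreter API: replace install "tensorflow" with install "tflite-runtime" in the pip block and change the import. Whether Chaquopy's 2.5 build of tflite-runtime actually includes FlexCTCGreedyDecoder is something you would need to verify on the device:

from os.path import dirname, join
from tflite_runtime.interpreter import Interpreter  # instead of tf.lite.Interpreter

def run_tflite_ocr(input_data):
    model_path = join(dirname(__file__), "keras_ocr_float16_ctc.tflite")
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]["index"], input_data)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]["index"])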