TFLite converter not replacing dummy_function with TFLite_Detection_PostProcess


I want to implement a tf.Module for decoding box predictions and applying NonMaxSuppression that is convertible to tflite.

This implementation includes elements from here.

It also follows this guide on operation fusion.

This is my code:

def decode_predictions_lite(anchor_boxes: tf.Tensor, box_pred: tf.Tensor, cls_pred: tf.Tensor, variance: tf.Tensor, image_shape):
    """a concrete function that decodes box and class predictions and applies NonMaxSuppression

    Args:
        anchor_boxes: tf.Tensor of shape [1, N, 4] representing anchor_boxes in the 'center_yxhw' format
        box_pred: tf.Tensor of shape [b, N, 4] representing encoded box predictions from the model's output
        cls_pred: tf.Tensor of shape [b, N, num_classes] representing the class logits from the model's output
        variance: tf.Tensor of shape [4] representing the box variance that was used when encoding the boxes
        image_shape: tf.Tensor of shape [2] representing the height and width of the input image
    Returns:
        the decoded boxes and the placeholder NMS outputs
        (boxes, classes, scores, num_detections)
    Note: 
        - N is the number of anchor boxes. All input tensors are of dtype float32
        - b is the batch_size. For now, only a batch_size of 1 is supported.
    """
    scores = tf.sigmoid(cls_pred)
    boxes = box_pred * variance
    decoded_boxes = tf.concat(
        [
            # yx-center offsets are scaled by the anchor's height/width
            boxes[..., :2] * anchor_boxes[..., 2:] + anchor_boxes[..., :2],
            # height/width are predicted in log-space relative to the anchor
            tf.math.exp(boxes[..., 2:]) * anchor_boxes[..., 2:],
        ], axis=-1
    )
    
    # Normalize anchor coordinates for TFLite's NMS operation.
    normalize_factor = tf.tile(image_shape, [2])
    anchor_boxes = anchor_boxes / normalize_factor
    anchor_boxes = tf.squeeze(anchor_boxes) # squeeze so the anchor_boxes are of shape (N, 4)
    
    # normalize box coordinates for TFLite's NMS operation
    decoded_boxes_rel = decoded_boxes / normalize_factor

    def get_implements_signature():
        implements_signature = ' '.join([
            'name: "%s"' % 'TFLite_Detection_PostProcess',
            'attr { key: "max_detections" value { i: %d } }' % 100,
            'attr { key: "max_classes_per_detection" value { i: %d } }' % 1,
            'attr { key: "detections_per_class" value { i: %d } }' % 5,
            'attr { key: "use_regular_nms" value { b: %s } }' % "false",  # boolean must be lowercase
            'attr { key: "nms_score_threshold" value { f: %f } }' % 0.1,
            'attr { key: "nms_iou_threshold" value { f: %f } }' % 0.5,
            'attr { key: "y_scale" value { f: %f } }' % 1.0,
            'attr { key: "x_scale" value { f: %f } }' % 1.0,
            'attr { key: "h_scale" value { f: %f } }' % 1.0,
            'attr { key: "w_scale" value { f: %f } }' % 1.0,
            'attr { key: "num_classes" value { i: %d } }' % num_classes,
        ])
        return implements_signature
    
    @tf.function(experimental_implements=get_implements_signature())
    def dummy_postprocessing_nms(input_boxes, input_scores, input_anchors):
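        # Placeholder outputs only; the converter is expected to replace this
        # function's body with the fused TFLite_Detection_PostProcess op.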
        boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
        scores = tf.constant(0.0, dtype=tf.float32, name='scores')
        classes = tf.constant(0.0, shape=(1, 100), dtype=tf.float32, name='classes')
        num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
        return boxes, classes, scores, num_detections
    
    return decoded_boxes, dummy_postprocessing_nms(decoded_boxes_rel, scores, anchor_boxes)

class PredictionDecoderLite(tf.Module):
    def __init__(self):
        super(PredictionDecoderLite, self).__init__()
    
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),  # anchor_boxes
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),  # box_pred
        tf.TensorSpec(shape=[1, None, None], dtype=tf.float32),  # cls_pred
        tf.TensorSpec(shape=[4], dtype=tf.float32),  # variance
        tf.TensorSpec(shape=[2], dtype=tf.float32)   # image_shape
    ])
    def decode_preds_lite(self, anchor_boxes, box_pred, cls_pred, variance, image_shape):
        return decode_predictions_lite(anchor_boxes, box_pred, cls_pred, variance, image_shape)
    
decoder_module = PredictionDecoderLite()
concrete_fn_with_nms = decoder_module.decode_preds_lite.get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn_with_nms], trackable_obj=decoder_module)

tflite_pred_decoder = converter.convert()
with open(tflite_postprocess_path, 'wb') as f:
    f.write(tflite_pred_decoder)

The issue is that the converter does not rewrite the dummy function to TFLite's custom NMS operation. I tested the TFLite module with the following code:

image_shape = [HEIGHT, WIDTH, 3]
variance = tf.constant([1.0, 1.0, 1.0, 1.0])

anchor_generator = kcv.models.RetinaNet.default_anchor_generator(bounding_box_format)
anchors = anchor_generator(image_shape=image_shape)
anchors = ops.concatenate([a for a in anchors.values()], axis=0)
anchors = tf.expand_dims(anchors, axis=0)
print("anchors: ", anchors.shape)   # anchors:  (1, 76725, 4)
print("boxes: ", predictions['box'].shape)  # boxes:  (1, 76725, 4)
print("scores: ", predictions['classification'].shape)  # scores:  (1, 76725, 4)

interpreter = tf.lite.Interpreter(tflite_postprocess_path)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.resize_tensor_input(input_details[0]['index'], anchors.shape)
interpreter.resize_tensor_input(input_details[1]['index'], predictions['box'].shape)
interpreter.resize_tensor_input(input_details[2]['index'], predictions['classification'].shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], anchors.numpy())
interpreter.set_tensor(input_details[1]['index'], predictions['box'])
interpreter.set_tensor(input_details[2]['index'], predictions['classification'])
interpreter.set_tensor(input_details[3]['index'], variance.numpy())
interpreter.set_tensor(input_details[4]['index'], np.array(image_shape[:2], dtype='float32'))
interpreter.invoke()

# Retrieve outputs:
decoded_boxes = interpreter.get_tensor(output_details[0]['index'])
final_boxes = interpreter.get_tensor(output_details[1]['index'])
final_scores = interpreter.get_tensor(output_details[2]['index'])
final_classes = interpreter.get_tensor(output_details[3]['index'])
num_detections = interpreter.get_tensor(output_details[4]['index'])

I received zeros from my dummy function and no post-processed boxes. I don't know how to get more information out of the converter to see what is wrong. Does anyone have an idea how to proceed?
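
To at least verify whether the fusion happened, the ops of the converted model can be dumped; a minimal sketch, assuming TensorFlow 2.9 or newer, where tf.lite.experimental.Analyzer is available:

# Dump the converted model's structure. If the fusion worked, the op
# listing contains TFLite_Detection_PostProcess; otherwise it shows the
# individual ops of the dummy function.
tf.lite.experimental.Analyzer.analyze(model_content=tflite_pred_decoder)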

1 Answer

Answer by Robert Sundermeyer:

So... After almost two days of debugging, I finally found a workaround. Instead of using tf.lite.TFLiteConverter.from_concrete_functions(), we first save the module and then use tf.lite.TFLiteConverter.from_saved_model(). The TensorFlow Model Garden does it the same way (see here).

There were also some other issues in the code above, so here is an updated version. You can directly use the box and class predictions from a keras_cv.models.RetinaNet that was converted to TFLite.

class PredictionDecoderLite(tf.Module):
    def __init__(self, num_classes, box_variance):
        self._num_classes = num_classes
        self._box_variance = box_variance
        super(PredictionDecoderLite, self).__init__()
        
        self.decode_preds_lite = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, 4], dtype=tf.float32),  # anchor_boxes
                tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),  # box_pred
                tf.TensorSpec(shape=[1, None, self._num_classes], dtype=tf.float32),  # cls_pred
            ]
        )(self._decode_preds_impl)
    
    def _decode_preds_impl(self, anchor_boxes, box_pred, cls_pred):
        """a concrete function that decodes box and class predictions and applies NonMaxSuppression

        Args:
            anchor_boxes: tf.Tensor of shape [N, 4] representing anchor_boxes in the 'center_yxhw' format
            box_pred: tf.Tensor of shape [b, N, 4] representing encoded box predictions from the model's output
            cls_pred: tf.Tensor of shape [b, N, num_classes] representing the class logits from the model's output
        Returns:
            the (placeholder) NMS outputs in reversed order:
            (num_detections, scores, classes, boxes)
        Note: 
            - N is the number of anchor boxes. All input tensors are of dtype float32
            - b is the batch_size. For now, only a batch_size of 1 is supported.
              See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc#L894C3-L894C17
        """
        class_predictions = tf.sigmoid(cls_pred)

        # Rename the input tensors.
        # See: https://github.com/tensorflow/tensorflow/blob/8c93fe44deb850c5978229723e65b3a7130d7422/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc#L58
        with tf.name_scope('raw_outputs'):
            box_encodings = tf.identity(box_pred, name='box_encodings')
            class_predictions = tf.identity(class_predictions, name='class_predictions')
        anchor_boxes = tf.identity(anchor_boxes, name='anchors') # Rename anchor_boxes to match TFLite's custom NMS    

        # See: https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_graph_lib_tf2.py#L188
        def get_implements_signature():
            implements_signature = ' '.join([
                'name: "%s"' % 'TFLite_Detection_PostProcess',
                'attr { key: "max_detections" value { i: %d } }' % 100,
                'attr { key: "max_classes_per_detection" value { i: %d } }' % 1,
                'attr { key: "detections_per_class" value { i: %d } }' % 5,
                'attr { key: "use_regular_nms" value { b: %s } }' % "false",
                'attr { key: "nms_score_threshold" value { f: %f } }' % 0.1,
                'attr { key: "nms_iou_threshold" value { f: %f } }' % 0.5,
                'attr { key: "y_scale" value { f: %f } }' % (1.0/self._box_variance[0]),
                'attr { key: "x_scale" value { f: %f } }' % (1.0/self._box_variance[1]),
                'attr { key: "h_scale" value { f: %f } }' % (1.0/self._box_variance[2]),
                'attr { key: "w_scale" value { f: %f } }' % (1.0/self._box_variance[3]),
                'attr { key: "num_classes" value { i: %d } }' % cls_pred.shape[-1],
            ])
            return implements_signature
        
        @tf.function(experimental_implements=get_implements_signature())
        def dummy_postprocessing_nms(input_boxes, input_scores, input_anchors):
            boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
            scores = tf.constant(0.0, dtype=tf.float32, name='scores')
            classes = tf.constant(0.0, shape=(1, 100), dtype=tf.float32, name='classes')
            num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
            return boxes, classes, scores, num_detections
        # Return the placeholder outputs in reversed order.
        return dummy_postprocessing_nms(box_encodings, class_predictions, anchor_boxes)[::-1]
    
decoder_module = PredictionDecoderLite(
    num_classes=num_classes, 
    box_variance=[0.1, 0.1, 0.2, 0.2]
)
concrete_fn_with_nms = decoder_module.decode_preds_lite.get_concrete_function()

# Export SavedModel
tf.saved_model.save(
    decoder_module,
    os.path.join(model_dir, 'detection_module'),
    signatures=concrete_fn_with_nms)

# Convert SavedModel
converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir=os.path.join(model_dir, 'detection_module'),
)
converter.allow_custom_ops = True
tflite_pred_decoder = converter.convert()
with open(tflite_postprocess_path, 'wb') as f:
    f.write(tflite_pred_decoder)
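
The fusion can be verified the same way as sketched in the question: tf.lite.experimental.Analyzer.analyze(model_content=tflite_pred_decoder) should now list TFLite_Detection_PostProcess among the ops.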

We can test the module the same way as before. Note that the order of the input and output tensors is a bit strange; a more robust name-based lookup is sketched after the test code below.

image_shape = [HEIGHT, WIDTH, 3]

anchor_generator = kcv.models.RetinaNet.default_anchor_generator("center_yxhw")
anchors = anchor_generator(image_shape=image_shape)
anchors = ops.concatenate([a for a in anchors.values()], axis=0).numpy()
box_pred = predictions['box']
cls_pred = predictions['classification']

print("anchors: ", anchors.shape)   # anchors:  (76725, 4)
print("boxes: ", box_pred.shape)    # boxes:    (1, 76725, 4)
print("scores: ", cls_pred.shape)   # scores:   (1, 76725, 4)

interpreter = tf.lite.Interpreter(tflite_postprocess_path)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
signature_list = interpreter.get_signature_list()
print(input_details)
print(output_details)

interpreter.resize_tensor_input(input_details[0]['index'], box_pred.shape)
interpreter.resize_tensor_input(input_details[1]['index'], anchors.shape)
interpreter.resize_tensor_input(input_details[2]['index'], cls_pred.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], box_pred)
interpreter.set_tensor(input_details[1]['index'], anchors)
interpreter.set_tensor(input_details[2]['index'], cls_pred)
interpreter.invoke()

# Retrieve outputs:
final_boxes = interpreter.get_tensor(output_details[1]['index'])
final_scores = interpreter.get_tensor(output_details[0]['index'])
final_classes = interpreter.get_tensor(output_details[3]['index'])
num_detections = interpreter.get_tensor(output_details[2]['index'])
n = int(num_detections[0])
print("final_boxes: ", final_boxes[0, :n])
print("final_scores: ", final_scores[0, :n])
print("final_classes: ", final_classes[0, :n])
print("num_detections: ", n)

It is worth noting that TFLite's custom NMS operation does the box decoding for you (see here). So far I have not found any documentation about this op; it would have made things much easier.
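
For reference, the decoding that the fused op performs internally looks roughly like the following. This is only a sketch reconstructed from reading detection_postprocess.cc (linked above), not an authoritative implementation:

import numpy as np

def decode_center_size(box, anchor, y_scale, x_scale, h_scale, w_scale):
    # box and anchor are in [ycenter, xcenter, h, w] format. The scale
    # attributes divide the raw predictions, which is why the code above
    # sets y_scale = 1 / box_variance[0] and so on.
    ycenter = box[0] / y_scale * anchor[2] + anchor[0]
    xcenter = box[1] / x_scale * anchor[3] + anchor[1]
    h = np.exp(box[2] / h_scale) * anchor[2]
    w = np.exp(box[3] / w_scale) * anchor[3]
    # The op emits corner-format boxes: [ymin, xmin, ymax, xmax].
    return [ycenter - h / 2, xcenter - w / 2, ycenter + h / 2, xcenter + w / 2]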