I want to implement a tf.Module
for decoding box predictions and applying NonMaxSuppression that is convertible to TFLite.
This implementation includes elements from here.
It also follows this guide on operation fusion.
This is my code:
import tensorflow as tf

def decode_predictions_lite(anchor_boxes: tf.Tensor, box_pred: tf.Tensor, cls_pred: tf.Tensor, variance: tf.Tensor, image_shape):
    """A concrete function that decodes box and class predictions and applies NonMaxSuppression.

    Args:
        anchor_boxes: tf.Tensor of shape [1, N, 4] representing anchor_boxes in the 'center_yxhw' format
        box_pred: tf.Tensor of shape [b, N, 4] representing encoded box predictions from the model's output
        cls_pred: tf.Tensor of shape [b, N, num_classes] representing the class logits from the model's output
        variance: tf.Tensor of shape [4] representing the box variance that was used when encoding the boxes
        image_shape: tf.Tensor of shape [2] representing the height and width of the input image

    Returns:
        The decoded boxes and the outputs of the (dummy) NMS function.

    Note:
        - N is the number of anchor boxes. All input tensors are of dtype float32.
        - b is the batch_size. For now, only a batch_size of 1 is supported.
    """
    scores = tf.sigmoid(cls_pred)
    boxes = box_pred * variance
    # Standard 'center_yxhw' decoding: offsets are scaled by the anchor's
    # height/width and added to the anchor's center; sizes are exp-scaled.
    decoded_boxes = tf.concat(
        [
            boxes[..., :2] * anchor_boxes[..., 2:] + anchor_boxes[..., :2],
            tf.math.exp(boxes[..., 2:]) * anchor_boxes[..., 2:],
        ],
        axis=-1,
    )
    # Normalize anchor coordinates for TFLite's NMS operation.
    normalize_factor = tf.tile(image_shape, [2])
    anchor_boxes = anchor_boxes / normalize_factor
    anchor_boxes = tf.squeeze(anchor_boxes)  # squeeze so the anchor_boxes are of shape (N, 4)
    # Normalize box coordinates for TFLite's NMS operation.
    decoded_boxes_rel = decoded_boxes / normalize_factor

    def get_implements_signature():
        implements_signature = ' '.join([
            'name: "%s"' % 'TFLite_Detection_PostProcess',
            'attr { key: "max_detections" value { i: %d } }' % 100,
            'attr { key: "max_classes_per_detection" value { i: %d } }' % 1,
            'attr { key: "detections_per_class" value { i: %d } }' % 5,
            'attr { key: "use_regular_nms" value { b: %s } }' % 'false',  # protobuf bools are lowercase
            'attr { key: "nms_score_threshold" value { f: %f } }' % 0.1,
            'attr { key: "nms_iou_threshold" value { f: %f } }' % 0.5,
            'attr { key: "y_scale" value { f: %f } }' % 1.0,
            'attr { key: "x_scale" value { f: %f } }' % 1.0,
            'attr { key: "h_scale" value { f: %f } }' % 1.0,
            'attr { key: "w_scale" value { f: %f } }' % 1.0,
            'attr { key: "num_classes" value { i: %d } }' % num_classes,  # num_classes must be defined in the enclosing scope
        ])
        return implements_signature

    @tf.function(experimental_implements=get_implements_signature())
    def dummy_postprocessing_nms(input_boxes, input_scores, input_anchors):
        # Placeholder outputs only; the converter is expected to replace this
        # whole function with the fused TFLite_Detection_PostProcess custom op.
        boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
        scores = tf.constant(0.0, dtype=tf.float32, name='scores')
        classes = tf.constant(0.0, shape=(1, 100), dtype=tf.float32, name='classes')
        num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
        return boxes, classes, scores, num_detections

    return decoded_boxes, dummy_postprocessing_nms(decoded_boxes_rel, scores, anchor_boxes)
class PredictionDecoderLite(tf.Module):
    def __init__(self):
        super(PredictionDecoderLite, self).__init__()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),     # anchor_boxes
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),     # box_pred
        tf.TensorSpec(shape=[1, None, None], dtype=tf.float32),  # cls_pred
        tf.TensorSpec(shape=[4], dtype=tf.float32),              # variance
        tf.TensorSpec(shape=[2], dtype=tf.float32),              # image_shape
    ])
    def decode_preds_lite(self, anchor_boxes, box_pred, cls_pred, variance, image_shape):
        return decode_predictions_lite(anchor_boxes, box_pred, cls_pred, variance, image_shape)
decoder_module = PredictionDecoderLite()
concrete_fn_with_nms = decoder_module.decode_preds_lite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn_with_nms], trackable_obj=decoder_module)
tflite_pred_decoder = converter.convert()
with open(tflite_postprocess_path, 'wb') as f:
    f.write(tflite_pred_decoder)
The issue is that the converter does not rewrite the dummy function into TFLite's custom NMS operation. I tested the TFLite model with the following code:
import numpy as np
import keras_cv as kcv
from keras_cv.backend import ops  # assumed import path for the `ops` alias used below

image_shape = [HEIGHT, WIDTH, 3]
variance = tf.constant([1.0, 1.0, 1.0, 1.0])
anchor_generator = kcv.models.RetinaNet.default_anchor_generator(bounding_box_format)
anchors = anchor_generator(image_shape=image_shape)
anchors = ops.concatenate([a for a in anchors.values()], axis=0)
anchors = tf.expand_dims(anchors, axis=0)
print("anchors: ", anchors.shape) # anchors: (1, 76725, 4)
print("boxes: ", predictions['box'].shape) # boxes: (1, 76725, 4)
print("scores: ", predictions['classification'].shape) # scores: (1, 76725, 4)
interpreter = tf.lite.Interpreter(tflite_postprocess_path)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.resize_tensor_input(input_details[0]['index'], anchors.shape)
interpreter.resize_tensor_input(input_details[1]['index'], predictions['box'].shape)
interpreter.resize_tensor_input(input_details[2]['index'], predictions['classification'].shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], anchors.numpy())
interpreter.set_tensor(input_details[1]['index'], predictions['box'])
interpreter.set_tensor(input_details[2]['index'], predictions['classification'])
interpreter.set_tensor(input_details[3]['index'], variance.numpy())
interpreter.set_tensor(input_details[4]['index'], np.array(image_shape[:2], dtype='float32'))
interpreter.invoke()
# Retrieve outputs:
decoded_boxes = interpreter.get_tensor(output_details[0]['index'])
final_boxes = interpreter.get_tensor(output_details[1]['index'])
final_scores = interpreter.get_tensor(output_details[2]['index'])
final_classes = interpreter.get_tensor(output_details[3]['index'])
num_detections = interpreter.get_tensor(output_details[4]['index'])
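For reference, had the fusion worked, I would expect the NMS outputs to be consumed roughly like this (a sketch; only the first num_detections entries per image should be valid, the rest padding):

n = int(num_detections[0])  # number of valid detections for the first (only) image
for i in range(n):
    # final_boxes holds normalized coordinates; classes are float-encoded indices.
    print(f"box={final_boxes[0, i]}, score={final_scores[0, i]}, class={int(final_classes[0, i])}")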
I received zeros from my dummy function and no post-processed boxes. I don't know how to get more information out of the converter to see what is wrong. Does anyone have an idea how to proceed?
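For anyone stuck at the same point: one way to see which ops actually ended up in the flatbuffer is the TFLite Analyzer API (available in newer TF releases, as far as I know):

# Dumps the model structure; if fusion worked, a custom
# TFLite_Detection_PostProcess op should appear in the output.
tf.lite.experimental.Analyzer.analyze(model_content=tflite_pred_decoder)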
So... after almost two days of debugging, I finally found a workaround. Instead of using tf.lite.TFLiteConverter.from_concrete_functions(), we first save the module and then use tf.lite.TFLiteConverter.from_saved_model(). In the TensorFlow model garden, they do it the same way (see here). There were also some other issues in the code above, so here is an updated version. You can directly use the box and class predictions from a keras_cv.models.RetinaNet that was converted to tflite. We can test the module the same way as before. Note: the order of input and output tensors is a bit strange.
It is worth noting that TFLite's custom NMS operation does the box decoding for you (see here). So far I have not found any documentation about this op. It would have made things much easier.
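From reading the op's kernel source (an assumption on my part, since there are no official docs), the decoding it applies per anchor looks roughly like this sketch:

import numpy as np

def decode_like_tflite_nms(pred, anchor, scales=(1.0, 1.0, 1.0, 1.0)):
    # Rough reconstruction of the internal decoding of TFLite_Detection_PostProcess,
    # based on its kernel source; treat as an assumption, not official documentation.
    ty, tx, th, tw = pred    # encoded prediction [y, x, h, w]
    ya, xa, ha, wa = anchor  # anchor in 'center_yxhw' format
    y_scale, x_scale, h_scale, w_scale = scales
    ycenter = ty / y_scale * ha + ya
    xcenter = tx / x_scale * wa + xa
    h = np.exp(th / h_scale) * ha
    w = np.exp(tw / w_scale) * wa
    return ycenter, xcenter, h, w

With all scales set to 1.0 (as in the implements signature above), this matches the decoding in decode_predictions_lite, so the boxes should not be decoded again before being passed to the fused op.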