Problem saving a Keras model (.h5) in TensorFlow SavedModel format (KeyError: 'inputs')

In TF 2.4.0, I'm training a Keras RetinaNet model (code from https://github.com/fizyr/keras-retinanet). After training, I want to convert model.h5 to the TensorFlow SavedModel format, but the conversion fails with KeyError: 'inputs'.

Conversion code:

# Import libraries
import tensorflow as tf
from tensorflow import keras
from keras_retinanet import models
from keras_retinanet.models import load_model

# Load the model
model = load_model("model.h5", backbone_name="resnet50")

# Save the model
model.save('model_tf', save_format='tf')

Error (KeyError: 'inputs'):

Traceback (most recent call last):
  File "convert_h5_2_pb.py", line 11, in <module>
    model.save('model_tf', save_format='tf') 
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2001, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 156, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 89, in save
    save_lib.save(model, filepath, signatures, options)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1032, in save
    _, exported_graph, object_saver, asset_info = _build_meta_graph(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1198, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1132, in _build_meta_graph_impl
    signatures = signature_serialization.find_function_to_export(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 150, in list_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2612, in _list_functions_for_serialization
    functions = super(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3086, in _list_functions_for_serialization
    return (self._trackable_saved_model_saver
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 94, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 78, in functions_to_serialize
    return (self._get_serialized_attributes(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 56, in _get_serialized_attributes_internal
    super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 155, in wrap_layer_functions
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 273, in _replace_child_layer_functions
    child_layer._trackable_saved_model_saver._get_serialized_attributes(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 163, in wrap_layer_functions
    call_fn_with_losses = call_collection.add_function(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 505, in add_function
    self.add_trace(*self._input_signature)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 550, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 515, in wrapper
    inputs = call_collection.get_input_arg_value(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 454, in get_input_arg_value
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2603, in _get_call_arg_value
    return args_dict[arg_name]
KeyError: 'inputs'

What should I do to fix this?

I searched online and tried changing the code, but nothing helped.

There are 2 answers

Egorundel (BEST ANSWER)

The solution has been found!

The fix was to edit the file keras_resnet/layers/_batch_normalization.py in the pip-installed keras_resnet package, applying the code changes from this commit: github.com/broadinstitute/keras-resnet/commit/73c50f
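
For context, the traceback ends in _get_call_arg_value looking up 'inputs' by name, which is what you would expect if a layer declares call() with only *args and **kwargs. The pure-Python sketch below imitates that kind of name-based lookup. It is a simplified illustration, not TensorFlow's actual code and not the contents of the linked commit; lookup_call_arg, StarArgsLayer and NamedInputsLayer are hypothetical names made up for this example.

```python
import inspect


def lookup_call_arg(layer, arg_name, args, kwargs):
    """Find a call() argument by name, roughly as a name-based serializer might."""
    if arg_name in kwargs:
        return kwargs[arg_name]
    # Named positional parameters of call(), with `self` dropped.
    named = inspect.getfullargspec(type(layer).call).args[1:]
    args_dict = dict(zip(named, args))
    return args_dict[arg_name]  # KeyError if call() has no parameter with that name


class StarArgsLayer:
    # call() exposes no named parameters, so nothing can be found by name.
    def call(self, *args, **kwargs):
        return args[0]


class NamedInputsLayer:
    # call() names its first parameter explicitly.
    def call(self, inputs, training=None, **kwargs):
        return inputs


if __name__ == "__main__":
    print(lookup_call_arg(NamedInputsLayer(), "inputs", ("x",), {}))
    try:
        lookup_call_arg(StarArgsLayer(), "inputs", ("x",), {})
    except KeyError as err:
        print("KeyError:", err)
```

If the custom layer's call() looks like StarArgsLayer, giving it an explicit inputs parameter (as in NamedInputsLayer) is the kind of change that would let the lookup succeed.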

Nathaldien

First, check which version of Keras you're using. You write that you're using TensorFlow 2.4, but the latest release is 2.14. The repo also says "This project should work with keras 2.4 and tensorflow 2.3.0, newer versions might break support." That is fine for Keras, but I have some doubts about TensorFlow 2.4.
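
One way to check the installed versions without importing the packages is the standard-library importlib.metadata module (Python 3.8+). This is only a sketch: installed_versions is a hypothetical helper, and the distribution names in the list are assumed to match what pip installed in your environment.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them to your environment.
    wanted = ["tensorflow", "keras", "keras-retinanet", "keras-resnet"]
    for name, version in installed_versions(wanted).items():
        print(f"{name}: {version or 'not installed'}")
```

Compare the output against the versions the repo says it supports.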

Try saving it as an inference model (stripping the parts used only for training should make it more save-friendly). I suggest converting it like this:

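from keras_retinanet import models

# Load the trained model (same call as in the question)
model = models.load_model("model.h5", backbone_name="resnet50")

# Convert the training model to an inference model
inference_model = models.convert_model(model)

# Save the model in SavedModel format
inference_model.save('model_tf')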