Error loading a model with a custom layer (TensorFlow 2.6.2)

86 views Asked by At

I have the following custom layer in my Vision Transformer

class DataAugmentation(Layer):
    """Normalise, resize and randomly augment a batch of images.

    This is the version from the question that fails on ``load_model``:
    ``get_config`` extends the base-layer config, while ``__init__`` accepts
    only ``norm`` and ``SIZE``. Rebuilding the layer via ``cls(**config)``
    therefore raises ``TypeError: __init__() got an unexpected keyword
    argument 'name'``.

    Args:
        norm: Callable applied to the input first (invoked as ``norm(X)``).
        SIZE: Value passed as both arguments to ``Resizing``.
    """

    def __init__(self, norm, SIZE):
        # No keyword arguments are accepted or forwarded to the base class.
        super(DataAugmentation, self).__init__()
        self.norm = norm
        self.SIZE = SIZE
        # Presumably the Keras preprocessing layers of the same names.
        self.resize = Resizing(SIZE, SIZE)
        self.flip = RandomFlip('horizontal')
        self.rotation = RandomRotation(factor=0.02)
        self.zoom = RandomZoom(height_factor=0.2, width_factor=0.2)

    def call(self, X):
        """Apply norm, resize, flip, rotation and zoom, in that order."""
        x = self.norm(X)
        x = self.resize(x)
        x = self.flip(x)
        x = self.rotation(x)
        x = self.zoom(x)
        return x

    def get_config(self):
        """Return the base-layer config extended with ``norm`` and ``SIZE``."""
        config = super().get_config()
        # NOTE(review): ``norm`` is stored as the object itself; whether it
        # round-trips through saving/loading should be confirmed.
        config.update({
            "norm": self.norm,
            "SIZE": self.SIZE,
        })
        return config

I saved the model after training, but whenever I load it I get the following error:

    File "test_vit.py", line 313, in <module>
    best_model = keras.models.load_model("ViT-Model-new.h5")
  File "/usr/local/lib/python3.6/dist-packages/keras/saving/save.py", line 201, in load_model
    compile)
  File "/usr/local/lib/python3.6/dist-packages/keras/saving/hdf5_format.py", line 181, in load_model_from_hdf5
    custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/dist-packages/keras/saving/model_config.py", line 52, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/serialization.py", line 212, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py", line 678, in deserialize_keras_object
    list(custom_objects.items())))
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 663, in from_config
    config, custom_objects)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 1273, in reconstruct_from_config
    process_layer(layer_data)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 1255, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/serialization.py", line 212, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py", line 681, in deserialize_keras_object
    deserialized_obj = cls.from_config(cls_config)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 748, in from_config
    return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'name'

What I tried:

1- I put @tf.keras.utils.register_keras_serializable() before the class definition

2- I loaded the model with the custom object scope

# Attempt 2: load the model inside a scope that maps the name
# "DataAugmentation" to the custom class.
with tf.keras.utils.custom_object_scope({"DataAugmentation": DataAugmentation}):
    model = load_model("ViT-Model-new.h5")

Both attempts produce the same error.

My TensorFlow version is 2.6.2.

1

There is 1 answer

0
Dr. Snoopy On BEST ANSWER

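Your implementation of the layer is not correct. When the model is loaded, Keras rebuilds the layer with cls(**config), and the config returned by super().get_config() contains base-layer entries such as name, which your __init__ does not accept. You need to take keyword arguments (**kwargs) in the constructor __init__ and pass them to the superclass: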

class DataAugmentation(Layer):
    """Augmentation layer whose constructor accepts base-layer keyword arguments.

    Args:
        norm: Object stored on the layer as ``self.norm``.
        SIZE: Value passed as both arguments to ``Resizing``.
        **kwargs: Extra keyword arguments (e.g. ``name``) forwarded unchanged
            to the ``Layer`` constructor.
    """
    def __init__(self, norm, SIZE, **kwargs):
        # Forward the extra keyword arguments to the base class so that
        # config entries such as ``name`` no longer raise a TypeError.
        super(DataAugmentation, self).__init__(**kwargs)
        self.norm = norm
        self.SIZE = SIZE
        # Presumably the Keras preprocessing layers of the same names.
        self.resize = Resizing(SIZE, SIZE)
        self.flip = RandomFlip('horizontal')
        self.rotation = RandomRotation(factor=0.02)
        self.zoom = RandomZoom(height_factor=0.2, width_factor=0.2)