I have the following custom layer in my Vision Transformer
class DataAugmentation(Layer):
def __init__(self, norm, SIZE):
super(DataAugmentation, self).__init__()
self.norm = norm
self.SIZE = SIZE
self.resize = Resizing(SIZE, SIZE)
self.flip = RandomFlip('horizontal')
self.rotation = RandomRotation(factor=0.02)
self.zoom = RandomZoom(height_factor=0.2, width_factor=0.2)
def call(self, X):
x = self.norm(X)
x = self.resize(x)
x = self.flip(x)
x = self.rotation(x)
x = self.zoom(x)
return x
def get_config(self):
config = super().get_config()
config.update({
"norm": self.norm,
"SIZE": self.SIZE,
})
return config
I have saved the weights after training but whenever I load the weights I have the following error:
File "test_vit.py", line 313, in <module>
best_model = keras.models.load_model("ViT-Model-new.h5")
File "/usr/local/lib/python3.6/dist-packages/keras/saving/save.py", line 201, in load_model
compile)
File "/usr/local/lib/python3.6/dist-packages/keras/saving/hdf5_format.py", line 181, in load_model_from_hdf5
custom_objects=custom_objects)
File "/usr/local/lib/python3.6/dist-packages/keras/saving/model_config.py", line 52, in model_from_config
return deserialize(config, custom_objects=custom_objects)
File "/usr/local/lib/python3.6/dist-packages/keras/layers/serialization.py", line 212, in deserialize
printable_module_name='layer')
File "/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py", line 678, in deserialize_keras_object
list(custom_objects.items())))
File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 663, in from_config
config, custom_objects)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 1273, in reconstruct_from_config
process_layer(layer_data)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/functional.py", line 1255, in process_layer
layer = deserialize_layer(layer_data, custom_objects=custom_objects)
File "/usr/local/lib/python3.6/dist-packages/keras/layers/serialization.py", line 212, in deserialize
printable_module_name='layer')
File "/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py", line 681, in deserialize_keras_object
deserialized_obj = cls.from_config(cls_config)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 748, in from_config
return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'name'
What I tried:
1- I put @tf.keras.utils.register_keras_serializable() before the class definition
2- I loaded the model with the custom object scope
with tf.keras.utils.custom_object_scope({"DataAugmentation": DataAugmentation}):
model = load_model("ViT-Model-new.h5")
For both solutions I have the same error.
My tensorflow version is 2.6.2
Your implementation of the layer is not correct, you need to take keyword arguments (
**kwargs
) in the constructor__init__
and pass them to the superclass: