Is there any solution for failing to load model in tensorflow2.3?

I am trying to use tf.keras.models.load_model to load a saved model in TensorFlow 2.3, but I get the same error as in this issue: https://github.com/tensorflow/tensorflow/issues/41535

Loading a saved model seems like essential functionality, yet the issue is still unresolved. Does anyone know an alternative way to achieve the same result?

Answer by wangsy (best answer)

I found an alternative way to load a custom model in TensorFlow 2.3. It requires the following changes, which I will explain with some code snippets.

  • Change the __init__() of the custom model. Before:

    def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
        # When loading, kwargs['layers'] holds the serialized sub-layer configs
        layers = []
        layer_configs = kwargs.get('layers', [])
        for config in layer_configs:
            layer = tf.keras.layers.deserialize(config)
            layers.append(layer)
        super(custom_model, self).__init__(layers)  # custom_model is your custom model class
        self.mask_ratio = mask_ratio
        self.hyperparam = hyperparam
        ...
    

    After:

    def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
        # Extra config keys (e.g. 'layers', 'name') arrive in kwargs and are ignored
        super(custom_model, self).__init__()  # custom_model is your custom model class
        self.mask_ratio = mask_ratio
        self.hyperparam = hyperparam
        ...
    
  • Define two methods, get_config() and from_config(), in your custom model class:

    def get_config(self):
        config = {
            'mask_ratio': self.mask_ratio,
            'hyperparam': self.hyperparam
        }
        base_config = super(custom_model, self).get_config()
        # Merge with the base config; on duplicate keys, base_config wins
        return dict(list(config.items()) + list(base_config.items()))

    @classmethod
    def from_config(cls, config):
        # Every key returned by get_config() is passed on to __init__()
        return cls(**config)
    
  • After training, save the model in HDF5 ('h5') format:

    model.save(file_path, save_format='h5')
    
  • Finally, load the model as follows:

    # The custom_objects key must match the custom model's class name
    model = tf.keras.models.load_model(file_path, compile=False, custom_objects={'custom_model': custom_model})
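Why does __init__() still need **kwargs? Because from_config() calls cls(**config) with every key that get_config() returned, including keys contributed by the base class. The framework-free sketch below mimics that round trip in plain Python. BaseModel and CustomModel are hypothetical stand-ins, not Keras classes, and a real Keras base config contains more keys (for a Sequential subclass, presumably 'layers' as well as 'name'). Unlike the answer's version, this sketch forwards 'name' to the base class instead of dropping it.

```python
# Framework-free sketch of the get_config()/from_config() round trip.
# BaseModel and CustomModel are hypothetical stand-ins, NOT Keras classes.


class BaseModel:
    """Mimics a base class whose config carries keys of its own."""

    def __init__(self, name="base_model"):
        self.name = name

    def get_config(self):
        # A real Keras base class contributes keys such as 'name'
        # (and, for Sequential, 'layers'); this stand-in only has 'name'.
        return {"name": self.name}


class CustomModel(BaseModel):
    def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
        # Keys from the base config (here 'name') land in kwargs.
        # Forward 'name' and silently ignore any other extra keys.
        super().__init__(name=kwargs.get("name", "custom_model"))
        self.mask_ratio = mask_ratio
        self.hyperparam = hyperparam

    def get_config(self):
        config = {"mask_ratio": self.mask_ratio, "hyperparam": self.hyperparam}
        base_config = super().get_config()
        # Same merge as in the answer: on duplicate keys, base_config wins.
        return dict(list(config.items()) + list(base_config.items()))

    @classmethod
    def from_config(cls, config):
        return cls(**config)


original = CustomModel(mask_ratio=0.25, hyperparam=0.5, name="my_model")
config = original.get_config()
restored = CustomModel.from_config(config)

print(config)  # {'mask_ratio': 0.25, 'hyperparam': 0.5, 'name': 'my_model'}
print(restored.mask_ratio, restored.hyperparam, restored.name)  # 0.25 0.5 my_model
```

If you remove **kwargs from CustomModel.__init__(), from_config() raises TypeError because of the unexpected 'name' keyword argument.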