Wrap a pre-trained PyTorch model in a torch.nn.Module class


I want to learn how to convert a PyTorch model into TorchScript. To do that, I first have to define a torch.nn.Module class that wraps the model.

Previously, I used Hugging Face Diffusers or Transformers classes to wrap models and convert them into TorchScript. Now I want to know how to define the wrapper class myself. If I only have the downloaded PyTorch model file, is it possible to define a wrapper class? Is there anything else I need to know?
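Whether you need to write a wrapper at all depends on what torch.load returns for your file. A file saved with torch.save(model.state_dict(), ...) yields a dictionary mapping parameter names to tensors; it contains no architecture, so you need the original class (or a faithful re-implementation) to load it. A minimal sketch for checking which case you have; describe_checkpoint is a hypothetical helper name, and the "other" case covers, for example, training checkpoints that nest the state_dict under a key such as "model":

```python
import torch


def describe_checkpoint(obj):
    """Classify the object returned by torch.load (hypothetical helper)."""
    # A whole pickled module: the architecture came with the file,
    # although its class must still be importable when loading.
    if isinstance(obj, torch.nn.Module):
        return "module"
    # A plain state_dict: parameter/buffer names mapped to tensors, no architecture.
    if isinstance(obj, dict) and all(isinstance(v, torch.Tensor) for v in obj.values()):
        return "state_dict"
    # Anything else, e.g. {"model": state_dict, "epoch": 3} from a training loop.
    return "other"


layer = torch.nn.Linear(4, 2)
print(describe_checkpoint(layer))                                      # module
print(describe_checkpoint(layer.state_dict()))                         # state_dict
print(describe_checkpoint({"model": layer.state_dict(), "epoch": 3}))  # other
```

With the real file you would call describe_checkpoint(torch.load(PATH, map_location='cpu')). The printed keys below suggest it is the "state_dict" case.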

Below is my code for a downloaded pre-trained model.

import torch

PATH = 'model.pth'
pretrained_dict = torch.load(PATH)

for key in list(pretrained_dict.keys()):
    print(key)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()  # no layers are registered, so MyModel().state_dict() is empty

    def forward(self, x):
        return 0  # placeholder forward

model = MyModel()
model.load_state_dict(pretrained_dict)
model.eval()
example_input = torch.rand(1, 3, 224, 224) 
torch_script = torch.jit.trace(model, example_input)

output:

tok_embeddings.weight
norm.weight
output.weight
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.feed_forward.w1.weight
layers.0.feed_forward.w2.weight
layers.0.feed_forward.w3.weight
layers.0.attention_norm.weight
layers.0.ffn_norm.weight
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.feed_forward.w1.weight
layers.1.feed_forward.w2.weight
layers.1.feed_forward.w3.weight
layers.1.attention_norm.weight
layers.1.ffn_norm.weight
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
...
layers.31.feed_forward.w3.weight
layers.31.attention_norm.weight
layers.31.ffn_norm.weight
rope.freqs


--> 17 model.load_state_dict(pretrained_dict)
     18 model.eval()
     19 example_input = torch.rand(1, 3, 224, 224) 

File ~/text-generation-webui-main/installer_files/env/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036         error_msgs.insert(
   2037             0, 'Missing key(s) in state_dict: {}. '.format(
   2038                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MyModel:
    Unexpected key(s) in state_dict: "tok_embeddings.weight", "norm.weight", "output.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight", "layers.0.attention.wo.weight", "layers.0.feed_forward.w1.weight", "layers.0.feed_forward.w2.weight", "layers.....
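The error says that MyModel has no parameters with these names, so every key in the checkpoint is "unexpected". To load it, your class has to create submodules whose attribute names reproduce the keys exactly. A small sketch for reading the structure off the key names; summarize_keys is a hypothetical helper, and it assumes the repeated blocks follow the "layers.<i>." naming shown in the output above:

```python
import re


def summarize_keys(keys):
    """Group checkpoint keys into top-level names and per-layer suffixes."""
    layer_pattern = re.compile(r"^layers\.(\d+)\.(.+)$")
    layer_ids = set()
    per_layer = []   # suffixes such as "attention.wq.weight", in first-seen order
    top_level = []   # keys outside the repeated "layers.<i>." blocks
    for key in keys:
        match = layer_pattern.match(key)
        if match:
            layer_ids.add(int(match.group(1)))
            suffix = match.group(2)
            if suffix not in per_layer:
                per_layer.append(suffix)
        else:
            top_level.append(key)
    num_layers = max(layer_ids) + 1 if layer_ids else 0
    return num_layers, per_layer, top_level


keys = [
    "tok_embeddings.weight", "norm.weight", "output.weight",
    "layers.0.attention.wq.weight", "layers.0.feed_forward.w1.weight",
    "layers.1.attention.wq.weight", "layers.1.feed_forward.w1.weight",
    "rope.freqs",
]
num_layers, per_layer, top_level = summarize_keys(keys)
print(num_layers)  # 2
print(per_layer)   # ['attention.wq.weight', 'feed_forward.w1.weight']
print(top_level)   # ['tok_embeddings.weight', 'norm.weight', 'output.weight', 'rope.freqs']
```

On the real checkpoint, pass pretrained_dict.keys() instead (keys layers.0 through layers.31 give 32 layers), and print tuple(pretrained_dict[k].shape) for each key to read off sizes such as the embedding dimension and vocabulary size.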

There are 3 answers

Answer by Valentin Goldité

dict is not a reserved keyword, but it is the name of Python's built-in dictionary type, so assigning to it shadows the built-in. Call your parameter dictionary state_dict or something similar instead.
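A short sketch of why shadowing the built-in matters: once a variable in a scope is named dict, later calls to dict(...) in that scope hit the variable instead of the type. The function name shadowing_demo is purely illustrative.

```python
def shadowing_demo():
    dict = {"weight": 1.0}  # this local name hides the built-in dict type
    try:
        dict(a=1)  # calls the local dictionary object, not the dict type
    except TypeError as err:
        return str(err)
    return "no error"


print(shadowing_demo())  # 'dict' object is not callable
```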

Answer by Ori Yarden PhD

As @Valentin Goldité pointed out, name the variable state_dict rather than dict.

However, the other issue is that you are re-defining MyModel and then trying to load the state_dict into an instance of this new MyModel, which contains no layers or weights and does not have the "right" forward method (I'm assuming your actual forward method does not just return 0).

So you have to import the original MyModel definition, instantiate it, and then load the state_dict into that instance; the rest of what you have is fine:

import torch

# Do not re-define MyModel with an empty body, e.g.:
#
# class MyModel(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#
#     def forward(self, x):
#         return 0
#
# Instead, import MyModel from the file in which it is defined
# (MyModel_file_name is a placeholder for that module's name).
from MyModel_file_name import MyModel

PATH = 'model.pth'


def load_pytorch_model(PATH, **kwargs):
    '''
    Steps:
        (1): Load the state_dict from PATH.
        (2): Instantiate MyModel(**kwargs).
        (3): Load the state_dict into the MyModel instance.
        (4): model.eval() # optional
        (5): model.to(device='cuda:0') # optional
        (6): etc.

    Returns:
        An instance of MyModel with the loaded state_dict.
    '''
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
    model_dict = torch.load(PATH, map_location='cpu')
    model = MyModel(**kwargs)
    model.load_state_dict(model_dict)
    model.eval()
    # model.to(device='cuda:0')
    return model


model = load_pytorch_model(PATH)

# The example input must match what MyModel.forward expects.
example_input = torch.rand(1, 3, 224, 224)
torch_script = torch.jit.trace(model, example_input)
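A self-contained sketch of the same pattern, with a hypothetical TinyNet standing in for the real architecture and an in-memory buffer standing in for model.pth. For the checkpoint in the question, the class would have to create submodules whose attribute names reproduce the printed keys (tok_embeddings, layers.<i>.attention.wq, and so on). Those names suggest a decoder-only transformer language model, so its forward method most likely takes integer token IDs rather than a (1, 3, 224, 224) image tensor; that is an inference from the key names, not something the checkpoint states.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the real architecture."""

    def __init__(self, vocab_size=10, dim=8):
        super().__init__()
        # Attribute names become state_dict keys: "embed.weight", "proj.weight", "proj.bias"
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.proj = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens):
        return self.proj(self.embed(tokens))


# Simulate the downloaded file: only the state_dict is saved, not the class.
source = TinyNet()
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture first, then copy the saved weights into it.
model = TinyNet()
model.load_state_dict(torch.load(buffer))
model.eval()

# Trace with an input of the type forward() expects (integer token IDs here).
example_input = torch.randint(0, 10, (1, 5))
traced = torch.jit.trace(model, example_input)

with torch.no_grad():
    print(torch.allclose(model(example_input), traced(example_input)))  # True
```

The order is the point: the class defines the structure, and load_state_dict only fills in the tensors.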

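Then, if you print(f'torch_script = {torch_script.__dict__}'), you should get something like the following (this sample output comes from a different, convolutional model, so your module names will differ):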

torch_script = {'_non_persistent_buffers_set': set(), '_backward_pre_hooks': OrderedDict(), '_backward_hooks': OrderedDict(), '_is_full_backward_hook': None, '_forward_hooks': OrderedDict(), '_forward_hooks_with_kwargs': OrderedDict(), '_forward_hooks_always_called': OrderedDict(), '_forward_pre_hooks': OrderedDict(), '_forward_pre_hooks_with_kwargs': OrderedDict(), '_state_dict_hooks': OrderedDict(), '_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_post_hooks': OrderedDict(), '_name': 'MultiDepthWiseConvAndFullyConnected', '_actual_script_module': RecursiveScriptModule(
  original_name=MultiDepthWiseConvAndFullyConnected
  (depthwise_layers_3x3): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): CustomLeakyReLU(original_name=CustomLeakyReLU)
    ....
   ...
...
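Once tracing succeeds, the usual next step is to save the TorchScript module so it can be loaded later without the Python class definition. A minimal sketch using torch.jit.save and torch.jit.load, with a small hypothetical Sequential model and an in-memory buffer in place of a file; traced.code is often easier to read than traced.__dict__:

```python
import io

import torch

# A small hypothetical module, traced the same way as above.
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU()).eval()
example_input = torch.rand(2, 4)
traced = torch.jit.trace(model, example_input)

# The TorchScript source generated for forward().
print(traced.code)

# Serialize the TorchScript module; a file path such as "model_traced.pt" works too.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    print(torch.allclose(traced(example_input), reloaded(example_input)))  # True
```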
Answer by will f

Here's the documentation for Module.load_state_dict(state_dict, strict=True, assign=False):

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Notice that strict is True by default, and your MyModel does not register any submodules or parameters, so its own state_dict() is empty and none of the checkpoint keys match. Make sure your Module is defined with layers whose names match the state_dict keys. Setting strict=False in Module.load_state_dict suppresses the error, but with the current empty MyModel it would simply load nothing.
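To see what strict=False does in this situation, a minimal sketch with a hypothetical EmptyModel that, like the question's MyModel, registers no layers: the call no longer raises, but nothing is loaded, and the mismatch is only reported in the returned object.

```python
import torch


class EmptyModel(torch.nn.Module):
    """Mirrors the question's MyModel: no layers, so state_dict() is empty."""

    def forward(self, x):
        return x


checkpoint = torch.nn.Linear(4, 2).state_dict()  # keys: "weight", "bias"
model = EmptyModel()

# strict=False turns the key mismatch into a returned report instead of an error.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)      # []
print(result.unexpected_keys)   # ['weight', 'bias']
print(len(model.state_dict()))  # 0, so nothing was actually loaded
```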