I want to learn how to convert a PyTorch model into TorchScript. To do that, I first have to define a torch.nn.Module class that wraps the model.
I have previously used HuggingFace Diffusers or Transformers classes to wrap models and convert them into TorchScript. Now I want to know how to define the wrapper class myself. If I only have the downloaded PyTorch model, is it possible to define a wrapper class? Or is there anything else I have to know?
Below is my code for a downloaded pre-trained model.
import torch

PATH = 'model.pth'
pretrained_dict = torch.load(PATH)

# Inspect the checkpoint: print every parameter name it contains.
for key in list(pretrained_dict.keys()):
    print(key)

# Placeholder wrapper: it defines no layers yet.
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0

model = MyModel()
model.load_state_dict(pretrained_dict)
model.eval()

example_input = torch.rand(1, 3, 224, 224)
torch_script = torch.jit.trace(model, example_input)
Output:
tok_embeddings.weight
norm.weight
output.weight
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.feed_forward.w1.weight
layers.0.feed_forward.w2.weight
layers.0.feed_forward.w3.weight
layers.0.attention_norm.weight
layers.0.ffn_norm.weight
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.feed_forward.w1.weight
layers.1.feed_forward.w2.weight
layers.1.feed_forward.w3.weight
layers.1.attention_norm.weight
layers.1.ffn_norm.weight
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
...
layers.31.feed_forward.w3.weight
layers.31.attention_norm.weight
layers.31.ffn_norm.weight
rope.freqs
--> 17 model.load_state_dict(pretrained_dict)
    18 model.eval()
    19 example_input = torch.rand(1, 3, 224, 224)

File ~/text-generation-webui-main/installer_files/env/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036     error_msgs.insert(
   2037         0, 'Missing key(s) in state_dict: {}. '.format(
   2038             ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042         self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MyModel:
    Unexpected key(s) in state_dict: "tok_embeddings.weight", "norm.weight", "output.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight", "layers.0.attention.wo.weight", "layers.0.feed_forward.w1.weight", "layers.0.feed_forward.w2.weight", "layers.....
dict is the name of Python's built-in dictionary type, so shadowing it is bad practice (strictly speaking it is a builtin, not a reserved word). You should call your parameters dictionary state_dict or something like that instead.
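That naming point aside, the RuntimeError itself happens because MyModel registers no parameters or submodules at all, so load_state_dict finds every checkpoint key "unexpected": it matches keys such as tok_embeddings.weight against the attribute paths of the module you load into. The wrapper therefore has to rebuild the checkpoint's structure with matching names and shapes. Here is a minimal sketch of the principle, using a toy three-key checkpoint with hypothetical dimensions (the real keys above look like a LLaMA-style transformer, whose layer sizes aren't shown here):

import torch
import torch.nn as nn

# Toy stand-in for the downloaded checkpoint. Keys follow the same
# "attribute.path.weight -> tensor" convention load_state_dict expects.
# Sizes are hypothetical; the real checkpoint has many more keys.
pretrained_dict = {
    'tok_embeddings.weight': torch.rand(1000, 64),
    'norm.weight': torch.rand(64),
    'output.weight': torch.rand(1000, 64),
}

class Norm(nn.Module):
    # RMSNorm-style module with a single 'weight' parameter, matching
    # the lone 'norm.weight' key (nn.LayerNorm would also demand a bias).
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x / (x.pow(2).mean(-1, keepdim=True) + 1e-6).sqrt()

class MyModel(nn.Module):
    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        # Attribute names must match the checkpoint keys exactly:
        # 'tok_embeddings.weight' loads into self.tok_embeddings.weight, etc.
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.norm = Norm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        return self.output(self.norm(h))

model = MyModel()
model.load_state_dict(pretrained_dict)  # succeeds: names and shapes line up
model.eval()

# Trace with an input of the type the model actually takes:
# token ids for a language model, not a (1, 3, 224, 224) image tensor.
example_input = torch.randint(0, 1000, (1, 16))
torch_script = torch.jit.trace(model, example_input)

For the full checkpoint, the layers.N.* keys would come from an nn.ModuleList attribute named layers, each element holding submodules named attention (with wq/wk/wv/wo linear layers), feed_forward (with w1/w2/w3), attention_norm, and ffn_norm. Since these keys match Meta's LLaMA reference implementation, it is usually easier to reuse that repository's model definition (which reads the right dimensions from the accompanying params.json) than to reverse-engineer the wrapper class from the key names alone.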