I'm using forward hooks to extract layer outputs from a pre-trained CNN and use them as features for my model. I also want to use TorchScript for inference. The problem is that when I try to export any method other than 'forward', I get an error saying that 'forward' is missing for the registered forward hooks. Here is a minimal example:
from typing import Iterable, Callable, Tuple
from torch import Tensor, nn, ones, jit, empty
from torchvision.models import resnet50

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self.hooks = []
        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            hook = layer.register_forward_hook(self.save_outputs_hook(layer_id))
            self.hooks.append(hook)

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, input: Tuple[Tensor], output):
            print('Hi')
        return fn

    def forward(self, x: Tensor):
        return self.model(x)

    @jit.export
    def predict(self, x: Tensor):
        return self.model(x)

if __name__ == '__main__':
    dummy_input = ones(10, 3, 224, 224)
    resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
    features = resnet_features(dummy_input)
    script = jit.trace(resnet_features, dummy_input)
This fails with:
RuntimeError: Couldn't find method: 'forward' on class: '__torch__.torch.nn.modules.container.___torch_mangle_141.Sequential (of Python compilation unit at: 0x7fdc5a676da8)'
If I deregister the hooks or export forward instead of predict, this of course runs without a problem. Is there any way to force jit to include 'forward' so that the hooks can see it?
Use
jit.script(resnet_features)
instead of
jit.trace(resnet_features, dummy_input)
and it should work.
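To illustrate the difference, here is a minimal sketch of scripting a module that has an extra exported method. It assumes a hypothetical TinyNet (not the FeatureExtractor from the question) and registers no hooks, so it only shows that jit.script compiles both forward and a method marked with @jit.export; whether a given hook function is itself scriptable is not covered here.

```python
import torch
from torch import Tensor, jit, nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model; no forward hooks are registered."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

    @jit.export
    def predict(self, x: Tensor) -> Tensor:
        # Extra entry point: index of the largest output per row.
        return self.linear(x).argmax(dim=1)


model = TinyNet()
# jit.script compiles forward plus every method marked with @jit.export.
scripted = jit.script(model)

x = torch.ones(3, 4)
scores = scripted(x)           # calls the compiled forward
labels = scripted.predict(x)   # calls the compiled exported method
```

Scripting compiles the module from its source code rather than recording one example run, so forward does not have to be the only entry point.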