TorchScript for prediction is missing 'forward' when using forward hooks


I'm using forward hooks to extract intermediate layer outputs from a pre-trained CNN and use them as features for my model. I also want to use TorchScript for inference. The problem is that when I try to export any method other than 'forward', I get an error saying that 'forward' is missing for the registered forward hooks. Here is a minimal example:

from typing import Iterable, Callable, Tuple
from torch import Tensor, nn, ones, jit
from torchvision.models import resnet50

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self.hooks = []

        # Register a forward hook on each requested submodule.
        for layer_id in layers:
            layer = dict(self.model.named_modules())[layer_id]
            hook = layer.register_forward_hook(self.save_outputs_hook(layer_id))
            self.hooks.append(hook)

    def save_outputs_hook(self, layer_id: str) -> Callable:
        # The real hook stores the layer output as a feature; printing is
        # enough to reproduce the problem.
        def fn(_, input: Tuple[Tensor], output):
            print('Hi')
        return fn

    def forward(self, x: Tensor):
        return self.model(x)

    @jit.export
    def predict(self, x: Tensor):
        return self.model(x)

if __name__ == '__main__':
    dummy_input = ones(10, 3, 224, 224)
    resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
    features = resnet_features(dummy_input)  # works, hooks fire
    script = jit.trace(resnet_features, dummy_input)  # raises the error below

This fails with:

RuntimeError: Couldn't find method: 'forward' on class: '__torch__.torch.nn.modules.container.___torch_mangle_141.Sequential (of Python compilation unit at: 0x7fdc5a676da8)'

If I deregister the hooks, or export forward instead of predict, this of course runs without a problem. Is there any way to force 'forward' to be compiled by the JIT as well, so that it is visible to the hooks?
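
For reference, this is what I mean by deregistering: each register_forward_hook call returns a handle whose remove() method detaches the hook. A minimal sketch of that workaround (the trace then succeeds, but the feature extraction is of course lost):

for hook in resnet_features.hooks:
    hook.remove()
script = jit.trace(resnet_features, dummy_input)  # now traces without error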


1 Answer

Answered by Juliano S Assine:

Use jit.script(resnet_features) instead of jit.trace(resnet_features, dummy_input) and it should work.
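
A minimal sketch of the fix, following the question's example (the save/load round trip and the file name are illustrative additions, not part of the original answer):

script = jit.script(resnet_features)   # compiles 'forward' and the exported 'predict'

script.save("feature_extractor.pt")    # illustrative file name
loaded = jit.load("feature_extractor.pt")
features = loaded.predict(dummy_input)

Unlike tracing, which records a single call path through one method, scripting compiles the module's methods (including forward) recursively, so the missing-'forward' error from the hooked submodules goes away.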