TorchDynamo graph tracing error on a tensor slicing operation


I’m using the stable release of torch 2.2.1 for quantization-aware training (QAT) and hit the error below when capturing the graph with TorchDynamo. During a normal forward pass, the shape of tensor x1 is [1, 16, 210, 348], but during tracing it is [1, 16, 8, 348]. I have no idea what causes this. Details follow; any help is appreciated.

  1. model forward code snippet:
def forward(self, inputs):
    x = self.aggregator(inputs)
    x1, x2 = x[:, :self.dim_extract], x[:, self.dim_extract:]
    x1 = self.extractor_fraction_1(x1)
    x1 = self.extractor_fraction_2(x1)
    x = torch.cat([x1, x2], dim=1)  # concat along channel dimension
    return x
  2. error detail:
File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 765, in conv
    conv_backend = torch._C._select_conv_backend(**kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___mpfd_blocks_0_out_1_partial_conv_extractor_fraction_1(*(FakeTensor(..., size=(16, 8, 348), grad_fn=<SliceBackward0>),), **{}):
Given groups=1, weight of size [8, 8, 5, 5], expected input[1, 16, 8, 348] to have 8 channels, but got 16 channels instead

from user code:
   File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 247, in forward
    x_forward = mpfd_block(x_forward)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 189, in forward
    out_1 = self.out_1(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 113, in forward
    x = self.partial_conv(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 84, in forward
    x1 = self.extractor_fraction_1(x1)

[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner               0

I have tried substituting the NumPy-style slicing operation ':' with the torch.narrow function, but the error message remains unchanged.
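
One observation that may help narrow this down: the FakeTensor in the traceback has size (16, 8, 348), which is three dimensions with no batch axis. The sketch below uses plain Python shape arithmetic (no torch) to show what x[:, :self.dim_extract] would do to a 4-D shape versus a 3-D shape. It rests on assumptions: dim_extract = 8 is inferred from the weight size [8, 8, 5, 5], the pre-slice shape is taken to be [1, 16, 210, 348], and slice_dim1 is a hypothetical helper, not part of the model or of PyTorch.

```python
# Shape-only sketch (no torch required). `slice_dim1` is a hypothetical helper
# that mimics how `x[:, :stop]` changes a tensor's shape.

def slice_dim1(shape, stop):
    """Return the shape produced by x[:, :stop] for a tensor of the given shape."""
    shape = tuple(shape)
    if len(shape) < 2:
        raise ValueError("need at least two dimensions to slice along dim 1")
    return (shape[0], min(stop, shape[1])) + shape[2:]

dim_extract = 8  # assumption: inferred from the conv weight size [8, 8, 5, 5]

# 4-D input (N, C, H, W): dim 1 is the channel axis, so channels go 16 -> 8.
batched = slice_dim1((1, 16, 210, 348), dim_extract)

# 3-D input (C, H, W): dim 1 is the height axis, so height goes 210 -> 8
# and all 16 channels are kept.
unbatched = slice_dim1((16, 210, 348), dim_extract)

print(batched)    # (1, 8, 210, 348)
print(unbatched)  # (16, 8, 348), the size of the FakeTensor in the traceback
```

If this is what is happening, the example input used for graph capture may be missing its batch dimension, so the slice cuts the height from 210 to 8 instead of the channels from 16 to 8. That would also explain why switching to torch.narrow changes nothing, but it is a guess based on the traceback alone.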
