ONNX export failed on an operator with unrecognized namespace 'torch_scatter::scatter_max'

I have a PyTorch network like this:

import torch.nn as nn
import torch_scatter

class DGCN(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        ...
        out, argmax = torch_scatter.scatter_max(x, index, dim=0)
        ...

But when I try to export my model to ONNX, I get this error:

  ...
  File "/usr/local/lib/python3.9/dist-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.9/dist-packages/torch/onnx/utils.py", line 663, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.9/dist-packages/torch/onnx/utils.py", line 1909, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace 'torch_scatter::scatter_max'. 
If you are trying to export a custom operator, make sure you registered it with the right domain and version.

So, how exactly can I do this?

1 Answer

Philip Lassen

The max reduction attribute for ScatterElements was recently added in ONNX opset 18.

Unfortunately, the PyTorch-to-ONNX exporter hasn't been updated accordingly.

One approach you could take is to make the change in a fork of the PyTorch repository. You could add the following lines to torch/onnx/symbolic_opset18.py:

import functools

import torch
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

# Mirrors how the in-tree symbolic_opset*.py files register their symbolics.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)

@_onnx_symbolic("aten::scatter_max")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_max(g: jit_utils.GraphContext, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_max",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows index shape <= src shape, so we can only consider
    # taking index as a subset size of src, like PyTorch does. When the sizes
    # of src and index do not match, or there are dynamic axes, we slice src
    # to the shape of index to accommodate.
    if src_sizes != index_sizes or None in index_sizes:
        adjusted_shape = g.op("Shape", index)
        starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, starts, adjusted_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="max")
    else:
        # Check if scalar "src" has the same type as self (PyTorch allows a
        # different type for a scalar src, but not when src is a tensor).
        # If not, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="max",
        )

Note that this code was just shamelessly taken from symbolic_opset16.py, where the export of scatter_add is implemented.
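
With the symbolic in place, re-exporting with opset_version=18 should route scatter_max through ScatterElements. A minimal sketch, assuming the DGCN model above and a placeholder input shape:

import torch

model = DGCN()
dummy_input = torch.randn(6, 16)  # placeholder shape, adjust to your model
torch.onnx.export(model, dummy_input, "dgcn.onnx", opset_version=18)

Alternatively, if you want to avoid forking PyTorch, torch.onnx.register_custom_op_symbolic lets you register a symbolic for the torch_scatter::scatter_max namespace from your own code, which is what the error message's hint about domains is pointing at. Keep in mind that torch_scatter.scatter_max returns both the reduced values and the argmax indices, so a faithful symbolic would have to produce two outputs; the sketch below (with an assumed argument list matching torch_scatter's op schema) only shows the registration wiring:

import torch.onnx

def scatter_max_symbolic(g, src, index, dim, out, dim_size):
    # Hypothetical symbolic: build a ScatterElements-based subgraph here,
    # analogous to the function above, and return both outputs.
    ...

torch.onnx.register_custom_op_symbolic(
    "torch_scatter::scatter_max", scatter_max_symbolic, opset_version=18
)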