How can I get the associated tensor from a Torch FX Graph Node?

I want to get all the operations that occur within a torch module, along with how they are parameterized. To do this, I first subclassed torch.fx.Tracer so that no module is treated as a leaf. That way the traced graph contains no call_module nodes:

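import torch
import torch.fx

class MyTracer(torch.fx.Tracer):
    # Treat no module as a leaf, so tracing descends into every submodule's forward
    def is_leaf_module(self, m, module_qualified_name):
        return False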
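To see what overriding is_leaf_module changes, here is a minimal sketch comparing the default torch.fx.Tracer with MyTracer on a module that only wraps a Conv2d (Wrapper is a hypothetical stand-in, not part of the original code). The default tracer treats modules from torch.nn as leaves, so the convolution stays a single call_module node; MyTracer expands it into get_attr nodes for the parameters plus a call_function node.

```python
import torch
import torch.fx


class MyTracer(torch.fx.Tracer):
    # Treat no module as a leaf, so tracing descends into every submodule's forward
    def is_leaf_module(self, m, module_qualified_name):
        return False


class Wrapper(torch.nn.Module):
    # Hypothetical minimal module: a single convolution
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)


m = Wrapper()
# Default tracer: torch.nn.Conv2d is a leaf, so it appears as one call_module node
default_ops = [n.op for n in torch.fx.Tracer().trace(m).nodes]
# MyTracer: the convolution is traced into get_attr + call_function nodes
flat_ops = [n.op for n in MyTracer().trace(m).nodes]
print(default_ops)
print(flat_ops)
```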

I also have a basic module that I am working with:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)  # in_channels, out_channels, kernel_size
        
    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y

I construct an instance of the module like so and trace it:

m = MyModule()
graph = MyTracer().trace(m)
graph.print_tabular()

which gives:

opcode         name         target                                                     args                                                    kwargs
-------------  -----------  ---------------------------------------------------------  ------------------------------------------------------  --------
placeholder    x            x                                                          ()                                                      {}
get_attr       conv_weight  conv.weight                                                ()                                                      {}
get_attr       conv_bias    conv.bias                                                  ()                                                      {}
call_function  conv2d       <built-in method conv2d of type object at 0x7f99b6a0a1c0>  (x, conv_weight, conv_bias, (1, 1), (0, 0), (1, 1), 1)  {}
call_function  relu         <built-in method relu of type object at 0x7f99b6a0a1c0>    (conv2d,)                                               {}
call_function  add          <built-in function add>                                    (relu, conv2d)                                          {}
call_function  relu_1       <built-in method relu of type object at 0x7f99b6a0a1c0>    (add,)                                                  {}
output         output       output                                                     (relu_1,)                                               {}
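Note that graph.print_tabular() depends on the third-party tabulate package. A sketch that reads the same columns straight off graph.nodes (reusing MyTracer and MyModule from above) also makes the underlying problem visible: for the get_attr nodes, node.target is only the qualified attribute name as a string, not the tensor itself, so the Graph alone does not hold the parameter values.

```python
import torch
import torch.fx


class MyTracer(torch.fx.Tracer):
    # Treat no module as a leaf, so tracing descends into every submodule's forward
    def is_leaf_module(self, m, module_qualified_name):
        return False


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y


graph = MyTracer().trace(MyModule())

# Same information as print_tabular, without the tabulate dependency
for node in graph.nodes:
    print(node.op, node.name, node.target, node.args, node.kwargs)

# get_attr targets are plain strings naming attributes on the traced module
attr_targets = [node.target for node in graph.nodes if node.op == "get_attr"]
print(attr_targets)
```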

How do I get the actual tensors for conv_weight and conv_bias from the graph, without accessing them directly on the model (via m.conv.weight or m.conv.bias)?

Answer by iHowell:

After more searching and some outside help, I was pointed to the Interpreter pattern (https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern). An Interpreter executes the graph node by node, so each handler receives the concrete argument values, including the parameter tensors that the get_attr nodes resolve to. So I built this small interpreter, which prints Conv2D information:

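import torch
import torch.fx

class MyInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        # args already holds concrete values: the get_attr nodes have been
        # replaced by the actual weight and bias tensors
        if target == torch.conv2d:
            print('CONV2D')
            print('kernel', args[1].shape)
            print('bias', args[2].shape)
        return super().call_function(target, args, kwargs)

# Wrap the traced graph so its get_attr targets resolve against m's parameters
gm = torch.fx.GraphModule(m, graph)
MyInterpreter(gm).run(torch.randn((3, 3, 3, 3)))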

yields:

CONV2D
kernel torch.Size([3, 3, 3, 3])
bias torch.Size([3])
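The same pattern generalizes beyond Conv2D. A minimal sketch, assuming you want the tensor behind every node rather than just the convolution arguments: override Interpreter.run_node to record each node's result by node name (RecordingInterpreter and its values dict are hypothetical names for illustration). For get_attr nodes specifically, no execution is needed at all: Interpreter.fetch_attr resolves the qualified target string (e.g. "conv.weight") against the GraphModule.

```python
import torch
import torch.fx


class MyTracer(torch.fx.Tracer):
    # Treat no module as a leaf, so parameters appear as get_attr nodes
    def is_leaf_module(self, m, module_qualified_name):
        return False


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y


class RecordingInterpreter(torch.fx.Interpreter):
    """Hypothetical helper: runs the graph and keeps every node's value, keyed by node name."""

    def __init__(self, module):
        super().__init__(module)
        self.values = {}

    def run_node(self, n):
        result = super().run_node(n)
        self.values[n.name] = result  # our own reference survives value garbage collection
        return result


m = MyModule()
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
interp = RecordingInterpreter(gm)
interp.run(torch.randn(3, 3, 3, 3))

# Static alternative for parameters: resolve each get_attr target string without running the graph
static = {n.name: interp.fetch_attr(n.target) for n in gm.graph.nodes if n.op == "get_attr"}

print(interp.values["conv_weight"].shape)
print(static["conv_bias"].shape)
```

Recording in run_node rather than in call_function captures placeholder, get_attr, and output nodes as well, so every node in the print_tabular listing gets a value.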