I want to be able to get all the operations that occur within a torch module, along with how they are parameterized. To do this, I first made a torch.fx.Tracer that disables leaf modules so that I can get the graph without call_module nodes:
class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return False
I also have a basic module that I am working with:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y
I construct an instance of the module like so and trace it:
m = MyModule()
graph = MyTracer().trace(m)
graph.print_tabular()
which gives:
opcode name target args kwargs
------------- ----------- --------------------------------------------------------- ------------------------------------------------------ --------
placeholder x x () {}
get_attr conv_weight conv.weight () {}
get_attr conv_bias conv.bias () {}
call_function conv2d <built-in method conv2d of type object at 0x7f99b6a0a1c0> (x, conv_weight, conv_bias, (1, 1), (0, 0), (1, 1), 1) {}
call_function relu <built-in method relu of type object at 0x7f99b6a0a1c0> (conv2d,) {}
call_function add <built-in function add> (relu, conv2d) {}
call_function relu_1 <built-in method relu of type object at 0x7f99b6a0a1c0> (add,) {}
output output output (relu_1,) {}
How do I actually get the associated parameters conv_weight and conv_bias without accessing them directly in the model (via m.conv.weight or m.conv.bias)?
After additional searching and outside assistance, I was shown the Interpreter pattern: https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern. This pattern lets you see each node, together with its concrete arguments, while executing the graph. So I built a small interpreter that prints out Conv2d information as the graph runs.
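A minimal sketch of such an interpreter (my reconstruction, not the original code): by the time call_function fires for the conv2d node, the get_attr nodes for conv.weight and conv.bias have already been evaluated, so the concrete parameter tensors arrive in args (positions 1 and 2, matching the args column in the table above). Comparing the target against torch.conv2d is an assumption based on the "built-in method conv2d" row of that table.

```python
import torch
import torch.fx

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # treat no module as a leaf, so conv is traced down to torch.conv2d
        return False

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y

class ConvInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        # args hold concrete tensors here, including the weight and bias
        # produced by the preceding get_attr nodes in the graph
        if target == torch.conv2d:
            weight, bias = args[1], args[2]
            print("conv2d weight shape:", tuple(weight.shape))
            print("conv2d bias shape:", tuple(bias.shape))
        return super().call_function(target, args, kwargs)

m = MyModule()
graph = MyTracer().trace(m)
# Interpreter needs a GraphModule, not a bare Graph
gm = torch.fx.GraphModule(m, graph)
out = ConvInterpreter(gm).run(torch.rand(1, 3, 8, 8))
```

The interpreter runs the graph node by node, so this works for any number of convolutions without ever touching m.conv.weight or m.conv.bias by name.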