What is the correct way to use a PyTorch Module inside a PyTorch Function?


We have a custom torch.autograd.Function z(x, t) which computes an output y in a way not amenable to direct automatic differentiation, and we have computed the Jacobians of the operation with respect to its inputs x and t, so we can implement the backward method.

However, the operation involves making several internal calls to a neural network, which we have implemented for now as a stack of torch.nn.Linear objects, wrapped in net, a torch.nn.Module. Mathematically, these are parameterized by t.

Is there any way that we can have net itself be an input to the forward method of z? Then we would return from our backward the list of products of the upstream gradient Dy and the parameter Jacobians dydt_i, one for each of the parameters t_i that are children of net (in addition to Dy*dydx, although x is data and does not need gradient accumulation).

Or do we really instead need to take t (actually a list of individual t_i), and reconstruct internally in z.forward the actions of all the Linear layers in net?
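For concreteness, here is a minimal sketch of the pattern we have in mind (hypothetical names and shapes, with zero placeholders standing in for the real products Dy*dydx and Dy*dydt_i): net is passed to forward as a non-tensor argument, its parameters are passed as explicit tensor inputs, and backward returns one gradient per input.

import torch
from torch.autograd import Function

class Z(Function):
    @staticmethod
    def forward(ctx, x, net, *params):
        ctx.net = net                      # non-tensor state; autograd does not track it
        ctx.save_for_backward(x, *params)  # tensors needed to form the Jacobians later
        with torch.no_grad():
            y = net(x)                     # stand-in for the real, non-differentiable operation
        return y

    @staticmethod
    def backward(ctx, Dy):
        x, *params = ctx.saved_tensors
        # Zero placeholders: replace with the analytic products Dy*dydx and Dy*dydt_i.
        dydx = torch.zeros_like(x)
        dydt = [torch.zeros_like(p) for p in params]
        # One gradient per forward input: x, then net (non-tensor -> None), then each t_i.
        return (dydx, None, *dydt)

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
x = torch.randn(2, 4)
y = Z.apply(x, net, *net.parameters())
y.sum().backward()   # gradients accumulate in p.grad for each p in net.parameters()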


There are 2 answers

Egan Johnson

I'm doing something similar, where the static restrictions on PyTorch functions are cumbersome. Similar in spirit to trialNerror's answer, I instead keep the PyTorch function methods static and pass in functions for them to use, which gets around the issues with making the functor non-static:

import torch
from torch.autograd import Function

class NonStaticBackward(Function):
    @staticmethod
    def forward(ctx, backward_fn, input):
        ctx.backward_fn = backward_fn  # stash the callable for use in backward
        # ... do other stuff
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Call into our non-static backward function.

        # Since we passed the backward function in as an input,
        # PyTorch expects a placeholder grad (None) for it.
        return None, ctx.backward_fn(ctx, grad_output)

Passing in the backward function every time gets annoying, so I usually wrap it:

def my_non_static_backward(ctx, grad_output):
    print("Hello from backward!")
    return grad_output

my_fn = lambda x: NonStaticBackward.apply(my_non_static_backward, x)

y = my_fn(torch.tensor([1.0, 2.0, 3.0], requires_grad=True))

This way, you can write the grad function somewhere where it has access to what it needs: no need to pass net.
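For example, a hypothetical sketch reusing NonStaticBackward from above, where the backward function simply closes over net:

import torch

net = torch.nn.Sequential(torch.nn.Linear(3, 3))

def backward_with_net(ctx, grad_output):
    # `net` is reachable through the closure, so there is no need to pass it to apply();
    # the matrix product below is just a placeholder for the real Jacobian product.
    return grad_output @ net[0].weight.detach()

my_fn = lambda x: NonStaticBackward.apply(backward_with_net, x)

y = my_fn(torch.ones(3, requires_grad=True))
y.sum().backward()   # invokes backward_with_net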

trialNerror

I guess you could create a custom functor that inherits from torch.autograd.Function and make the forward and backward methods non-static (i.e. remove the @staticmethod decorators in this example), so that net can be an attribute of your functor. That would look like:

class MyFunctor(torch.autograd.Function):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x, t):
        # store x and t (e.g. via self.save_for_backward) in the way you find useful
        # not sure how t is involved here
        return self.net(x)

    def backward(self, grad):
        # do your backward stuff
        ...

net = nn.Sequential(nn.Linear(...), ...)
z = MyFunctor(net)
y = z(x, t)

This will yield a warning that you are using a deprecated legacy way of creating autograd functions (because of the non-static methods), and you need to be extra careful about zeroing the gradients in net after having backpropagated. So it is not really convenient, but I am not aware of a better way to have a stateful autograd function.
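If the legacy-style functor above does run in your PyTorch version, the zeroing mentioned above would look roughly like this (hypothetical usage sketch):

y = z(x, t)
y.sum().backward()
# ... use the gradients now sitting in p.grad for each p in net.parameters() ...
net.zero_grad()   # clear them so they do not accumulate into the next backward pass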