We have a custom `torch.autograd.Function`, `z(x, t)`, which computes an output `y` in a way not amenable to direct automatic differentiation. We have computed the Jacobian of the operation with respect to its inputs `x` and `t`, so we can implement the `backward` method. However, the operation involves making several internal calls to a neural network, which for now we have implemented as a stack of `torch.nn.Linear` objects wrapped in `net`, a `torch.nn.Module`. Mathematically, these layers are parameterized by `t`.

Is there any way that we can have `net` itself be an input to the `forward` method of `z`? Then we would return from our `backward` the list of products of the upstream gradient `Dy` and the parameter Jacobians `dydt_i`, one for each of the parameters `t_i` that are children of `net` (in addition to `Dy*dydx`, although `x` is data and does not need gradient accumulation).

Or do we really instead need to take `t` (actually a list of the individual `t_i`), and reconstruct internally in `z.forward` the actions of all the `Linear` layers in `net`?

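Sketched out, the first option might look something like the following, where `run_black_box`, `vjp_x`, and `vjp_t` are placeholders for the operation and the hand-derived vector-Jacobian products, and the parameters are passed as explicit tensor inputs so that `backward` can return a gradient for each of them:

```python
import torch

class Z(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, net, *params):
        # net is passed only so that forward can call it; as a non-tensor
        # input it receives no gradient. Its parameters are passed explicitly
        # so that backward can return one gradient per t_i.
        y = run_black_box(x, net)  # hypothetical non-differentiable operation
        ctx.save_for_backward(x, *params)
        return y

    @staticmethod
    def backward(ctx, Dy):
        x, *params = ctx.saved_tensors
        grad_x = vjp_x(Dy, x, params)               # hypothetical: Dy * dydx
        grad_t = [vjp_t(Dy, x, p) for p in params]  # hypothetical: Dy * dydt_i, shaped like each t_i
        # one return value per forward input: x, then net (no gradient), then each t_i
        return (grad_x, None, *grad_t)

# y = Z.apply(x, net, *net.parameters())
```
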
I'm doing something similar, where the requirement that `torch.autograd.Function` methods be static is cumbersome. Similar in spirit to trialNerror's answer, I instead keep the Function methods static and pass in the functions they should use, which gets around the issues with trying to make the functor non-static.

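A minimal sketch of this pattern (the names `Custom`, `fwd_fn`, and `backward_fn` are illustrative):

```python
import torch

class Custom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, fwd_fn, backward_fn):
        # fwd_fn does the non-differentiable computation;
        # backward_fn maps the upstream gradient to a gradient w.r.t. x.
        ctx.backward_fn = backward_fn
        return fwd_fn(x)

    @staticmethod
    def backward(ctx, Dy):
        # one return value per argument of forward; the two functions get None
        return ctx.backward_fn(Dy), None, None
```

It would then be called as `Custom.apply(x, fwd_fn, backward_fn)`.
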
Passing in the backward function every time gets annoying, so I usually wrap it.

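Something along these lines, where `make_op` is a made-up name:

```python
def make_op(fwd_fn, backward_fn):
    # bind both functions once; the returned op is then used like an
    # ordinary function of x
    def op(x):
        return Custom.apply(x, fwd_fn, backward_fn)
    return op
```
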
This way, you can write the grad function somewhere where it has access to what it needs: no need to pass `net`.

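For the setup in the question, that might look roughly as follows; the function bodies are placeholders for the non-differentiable computation and the hand-derived `Dy * dydx` product, and the point is only that `backward_fn` can see `net` through its closure:

```python
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 4))

def fwd_fn(x):
    ...  # placeholder: the computation that calls net internally

def backward_fn(Dy):
    ...  # placeholder: compute Dy * dydx; net and its parameters are in scope here

z = make_op(fwd_fn, backward_fn)
```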