I have a function which takes 5 values as arguments and returns a scalar, i.e. a mapping of the form f: R^5 -> R. Hence, its Jacobian J is a matrix of dimension (1x5) and might as well be a row vector in the form of a gradient g.
I can easily compute the Jacobian for a single input x using torch.autograd.functional.jacobian:
J = torch.autograd.functional.jacobian(func=f, inputs=(x[0], x[1], x[2], x[3], x[4]))
Now, my questions are:
- What is the most efficient way to compute the Jacobian (gradient) multiple times when I have a range of values for one of the arguments?
- I do not need the gradient on its own for each of the values in the range, but I would like to multiply the Jacobian J with a vector v and get another vector (elementwise) y = Jv of the same dimension. Can I make the code more performant by using other functions such as jvp, vjp, or vmap?
- I've seen Jax mentioned in other places. Is it better suited to these kinds of problems?
I've made a small example showing what I am trying to accomplish below - just without using a loop or list comprehension.
import torch

def f(x0, x1, x2, x3, x4):
    return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

if __name__ == '__main__':
    a = torch.tensor(1.0)
    b = torch.tensor(1.0)
    c = torch.tensor(1.0)
    d = torch.tensor(1.0)
    e = torch.tensor(1.0)

    # Gradient for a single set of inputs.
    g = torch.autograd.functional.jacobian(func=f, inputs=(a, b, c, d, e))
    print(g)
""" (tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.)) """
    # One jacobian call per value of x0, collected with a loop.
    x0_values = torch.arange(1.0, 10.0, 1.0)
    g_list = []
    for x0 in x0_values:
        g = torch.autograd.functional.jacobian(func=f, inputs=(x0, b, c, d, e))
        g_list.append(torch.hstack(g))
    J = torch.vstack(g_list)
    print(J)
"""
tensor([[2., 3., 4., 5., 6.],
[4., 3., 4., 5., 6.],
[6., 3., 4., 5., 6.],
[8., 3., 4., 5., 6.],
[10., 3., 4., 5., 6.],
[12., 3., 4., 5., 6.],
[14., 3., 4., 5., 6.],
[16., 3., 4., 5., 6.],
[18., 3., 4., 5., 6.]])
"""
Since you tagged jax and asked about vmap, here's how you can do the equivalent computation using jax.vmap in place of your loops:
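Below is a minimal sketch of that approach; it assumes jax.jacobian with argnums=(0, 1, 2, 3, 4) to get the gradient with respect to each argument, and jax.vmap with in_axes=(0, None, None, None, None) so that only the first argument is batched (the helper names grad_f and batched_grad are just illustrative).

import jax
import jax.numpy as jnp

def f(x0, x1, x2, x3, x4):
    return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

# Gradient of f with respect to all five scalar arguments.
grad_f = jax.jacobian(f, argnums=(0, 1, 2, 3, 4))

b = c = d = e = 1.0
x0_values = jnp.arange(1.0, 10.0, 1.0)

# Map the gradient over the first argument only; the other four stay fixed.
batched_grad = jax.vmap(grad_f, in_axes=(0, None, None, None, None))

# batched_grad returns a tuple of five length-9 arrays (one per argument);
# stacking them column-wise gives the same (9, 5) matrix as the PyTorch loop.
J = jnp.stack(batched_grad(x0_values, b, c, d, e), axis=1)
print(J)

Because vmap traces f once and vectorizes the computation, there is no Python-level loop over x0_values.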