I have a function which takes 5 values as arguments and returns a scalar. That is, a mapping of the form f: R^5 -> R.

Hence, its Jacobian J is a matrix of dimension (1x5), which might just as well be treated as a row vector, i.e. the gradient g.

I can easily compute the Jacobian for a single input x using torch.autograd.functional.jacobian:

J = torch.autograd.functional.jacobian(func=f, inputs=(x[0], x[1], x[2], x[3], x[4]))

Now, my questions are:

  1. What is the most efficient way to compute the Jacobian (gradient) multiple times when I have a range of values for one of the arguments?
  2. I do not need the gradient on its own for each of the values in the range. Instead, I would like to multiply the Jacobian J with a vector v and get the product y = Jv for each value. Can I make the code more performant by using other functions such as jvp, vjp, or vmap?
  3. I've seen JAX mentioned in other places. Is it better suited for this kind of problem?

I've made a small example below showing what I am trying to accomplish; I would like to do the same without a loop or list comprehension.

import torch

def f(x0, x1, x2, x3, x4):
    return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

if __name__ == '__main__':
    a = torch.tensor(1.0)
    b = torch.tensor(1.0)
    c = torch.tensor(1.0)
    d = torch.tensor(1.0)
    e = torch.tensor(1.0)

    g = torch.autograd.functional.jacobian(func=f, inputs=(a, b, c, d, e))
    print(g)  
    """ (tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.)) """

    x0_values = torch.arange(1.0, 10.0, 1.0)

    g_list = []
    for x0 in x0_values:
        g = torch.autograd.functional.jacobian(func=f, inputs=(x0, b, c, d, e))
        g_list.append(torch.hstack(g))

    J = torch.vstack(g_list)
    print(J)

    """
    tensor([[2., 3., 4., 5., 6.],
            [4., 3., 4., 5., 6.],
            [6., 3., 4., 5., 6.],
            [8., 3., 4., 5., 6.],
            [10., 3., 4., 5., 6.],
            [12., 3., 4., 5., 6.],
            [14., 3., 4., 5., 6.],
            [16., 3., 4., 5., 6.],
            [18., 3., 4., 5., 6.]])
    """

There are 2 answers

jakevdp answered:

Since you tagged jax and asked about vmap, here's how you can do the equivalent computation using jax.vmap in place of your loops:

import jax
import jax.numpy as jnp

def f(x0, x1, x2, x3, x4):
  return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

a = b = c = d = e = jnp.float32(1.0)

g = jax.jacobian(f, argnums=(0, 1, 2, 3, 4))(a, b, c, d, e)
print(g)
# (Array(2., dtype=float32), Array(3., dtype=float32), Array(4., dtype=float32),
#  Array(5., dtype=float32), Array(6., dtype=float32))

x0 = jnp.arange(1.0, 10.0)

J = jax.vmap(
    jax.jacobian(f, argnums=(0, 1, 2, 3, 4)),
    in_axes=(0, None, None, None, None)
)(x0, b, c, d, e)

print(jnp.column_stack(J))
# [[ 2.  3.  4.  5.  6.]
#  [ 4.  3.  4.  5.  6.]
#  [ 6.  3.  4.  5.  6.]
#  [ 8.  3.  4.  5.  6.]
#  [10.  3.  4.  5.  6.]
#  [12.  3.  4.  5.  6.]
#  [14.  3.  4.  5.  6.]
#  [16.  3.  4.  5.  6.]
#  [18.  3.  4.  5.  6.]]
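
Since the question also asks about getting y = Jv directly rather than the full Jacobian, here is a minimal follow-up sketch using jax.jvp under the same vmap; the direction vector v below is a hypothetical placeholder, not something from the original question:

import jax
import jax.numpy as jnp

def f(x0, x1, x2, x3, x4):
    return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

b = c = d = e = jnp.float32(1.0)
x0 = jnp.arange(1.0, 10.0)

# hypothetical direction vector v, one tangent per argument
v = (jnp.float32(1.0),) * 5

def j_dot_v(x0, x1, x2, x3, x4):
    # jax.jvp returns (f(primals), J @ tangents) without materializing J
    _, jv = jax.jvp(f, (x0, x1, x2, x3, x4), v)
    return jv

y = jax.vmap(j_dot_v, in_axes=(0, None, None, None, None))(x0, b, c, d, e)
print(y)  # one J·v value per entry in x0

Forward mode (jvp) matches the Jv product asked about here; jax.vjp would give the transposed product vᵀJ instead.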
Karl answered:

For your example, you would do something like this:

import torch

# rewrite function to be broadcasted
def f(x, pows):
    return x.pow(pows).sum(-1)

# create batch of inputs with `requires_grad=True`
inputs = torch.ones(9,5)
inputs[:,0] = torch.arange(1.0, 10.0, 1.0)
inputs.requires_grad_(True)

# compute output
pows = torch.tensor([2,3,4,5,6])
output = f(inputs,pows)

# call backward
output.backward(torch.ones(output.shape))

print(inputs.grad)
# tensor([[ 2.,  3.,  4.,  5.,  6.],
#         [ 4.,  3.,  4.,  5.,  6.],
#         [ 6.,  3.,  4.,  5.,  6.],
#         [ 8.,  3.,  4.,  5.,  6.],
#         [10.,  3.,  4.,  5.,  6.],
#         [12.,  3.,  4.,  5.,  6.],
#         [14.,  3.,  4.,  5.,  6.],
#         [16.,  3.,  4.,  5.,  6.],
#         [18.,  3.,  4.,  5.,  6.]])
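
For the vmap part of the question in PyTorch itself, a minimal sketch using torch.func.vmap together with torch.func.jacrev is shown below. This assumes PyTorch 2.0 or newer and is meant as an illustration rather than part of the original answer:

import torch
from torch.func import jacrev, vmap  # assumes PyTorch >= 2.0

def f(x0, x1, x2, x3, x4):
    return x0 ** 2 + x1 ** 3 + x2 ** 4 + x3 ** 5 + x4 ** 6

b = c = d = e = torch.tensor(1.0)
x0_values = torch.arange(1.0, 10.0, 1.0)

# differentiate w.r.t. all five arguments, then batch over the first one
jac_fn = vmap(jacrev(f, argnums=(0, 1, 2, 3, 4)),
              in_dims=(0, None, None, None, None))
J = torch.stack(jac_fn(x0_values, b, c, d, e), dim=-1)  # shape (9, 5)
print(J)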