Computing the gradient of a batched function using JAX


I need to compute the gradient of a batched function using JAX. Here is a minimal example of what I would like to do:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)

u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.grad(u))

plt.plot(x, u(x))
plt.plot(x, ux(x))  # intended: plot du/dx (this line raises an error)
plt.show()

I have tried a variety of ways of making this work with vmap, but I can't get the code to run without removing the batch dimension from the input x. I have seen some workarounds using the Jacobian, but this doesn't seem natural, since the function in question is a scalar function of a single variable.
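For reference, one plausible form of the Jacobian workaround looks like the sketch below. The question doesn't show the workaround it refers to, so this exact form is an assumption. It applies `jax.jacfwd` per row, which yields a (1, 1) Jacobian for each sample, and then the extra axis has to be dropped by hand. That bookkeeping is why it feels unnatural for a scalar function.

```python
import jax
import jax.numpy as jnp

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
u = lambda x: jnp.sin(jnp.pi * x)

# Per-row Jacobian: each row has shape (1,) and u maps it to shape (1,),
# so jacfwd returns a (1, 1) matrix per row, i.e. (20, 1, 1) overall.
jac = jax.vmap(jax.jacfwd(u))(x)
print(jac.shape)  # (20, 1, 1)

# Drop the trailing singleton axis to get one derivative per sample.
ux_jac = jac[:, :, 0]  # shape (20, 1)
```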

In the end, u will be a neural network (implemented in Flax) that I need to differentiate with respect to its input (not the network parameters), so I cannot remove the batch dimension.

Answer by DavidJ (accepted answer)

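jax.grad is only defined for functions that return a scalar. With x of shape (20, 1), a single vmap maps over the first axis only, so u receives rows of shape (1,) and returns an array of shape (1,), not a scalar. To make the kernel (u) return a scalar, so that jax.grad makes sense, the remaining (trailing) batched dimension also needs to be mapped over.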
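A quick way to see the shape problem is to call u on a single row and then try the single-vmap version from the question. This is only a diagnostic sketch. It relies on `jax.grad` raising a `TypeError` for non-scalar outputs, which is how current JAX versions behave.

```python
import jax
import jax.numpy as jnp

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
u = lambda x: jnp.sin(jnp.pi * x)

# One vmap strips only the leading axis, so u sees rows of shape (1,).
row_shape = u(x[0]).shape
print(row_shape)  # (1,)

failed = False
try:
    jax.vmap(jax.grad(u))(x)
except TypeError as err:
    # jax.grad rejects functions whose output is not a scalar.
    failed = True
    print("single vmap failed:", err)

# Mapping over the trailing axis as well means u sees true scalars.
print(u(x[0, 0]).shape)  # ()
```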

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)

u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.vmap(jax.grad(u)))
# ux = lambda x : jax.lax.map(jax.vmap(jax.grad(u)), x) # sequential version
# ux = lambda x : jax.vmap(jax.grad(u))(x.reshape(-1)).reshape(x.shape) # flattened map version

plt.plot(x, u(x))
plt.plot(x, ux(x))  # plot du/dx = pi * cos(pi * x)
plt.show()
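To check that the three variants above are interchangeable for this u, you can compare each against the analytic derivative pi * cos(pi * x). This is a sanity-check sketch without plotting; the dictionary of variants is just for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
u = lambda x: jnp.sin(jnp.pi * x)
expected = jnp.pi * jnp.cos(jnp.pi * x)  # d/dx sin(pi x)

variants = {
    # vmap over rows, then over the single column: u sees scalars.
    "nested vmap": jax.vmap(jax.vmap(jax.grad(u))),
    # lax.map loops over rows sequentially; the inner vmap handles the column.
    "lax.map (sequential)": lambda x: jax.lax.map(jax.vmap(jax.grad(u)), x),
    # Flatten to shape (20,), map over scalars, then restore the shape.
    "flattened vmap": lambda x: jax.vmap(jax.grad(u))(x.reshape(-1)).reshape(x.shape),
}

results = {name: f(x) for name, f in variants.items()}
for name, r in results.items():
    ok = bool(jnp.allclose(r, expected, atol=1e-4))
    print(f"{name}: shape={r.shape}, matches analytic derivative={ok}")
```

All three return shape (20, 1). The vmap versions vectorize the whole computation, while `jax.lax.map` processes one row at a time, which typically lowers peak memory at the cost of speed.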

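Which composition of maps to use depends on what the batched dimensions represent. Here both axes of x index independent scalar inputs, so every axis can be mapped over; if u instead consumes a whole feature vector per sample, only the batch axis should be mapped.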
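For the neural-network case in the question, the same idea applies: write a per-sample function that returns a scalar, take `jax.grad` with respect to the input, and vmap over the batch axis only. The sketch below uses a tiny hand-written MLP (`init_params`, `mlp`, and `mlp_scalar` are hypothetical names) as a stand-in for a Flax model; with Flax, a call like `model.apply(variables, xi)` would presumably take the place of `mlp(params, xi)`, but that substitution is an assumption and is not tested here.

```python
import jax
import jax.numpy as jnp

def init_params(key, hidden=16):
    # Hypothetical 1 -> hidden -> 1 MLP standing in for a Flax model.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, hidden)),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)),
        "b2": jnp.zeros(1),
    }

def mlp(params, xi):
    # Works for one sample of shape (1,) or a batch of shape (N, 1).
    h = jnp.tanh(xi @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def mlp_scalar(params, xi):
    # Index out the single output so jax.grad sees a scalar.
    return mlp(params, xi)[0]

params = init_params(jax.random.PRNGKey(0))
x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)

# Gradient w.r.t. the input (argnums=1), mapped over the batch axis only;
# params are shared across the batch (in_axes=None).
dudx = jax.vmap(jax.grad(mlp_scalar, argnums=1), in_axes=(None, 0))(params, x)
print(dudx.shape)  # (20, 1)
```

Because each sample's output here depends only on its own input, `jax.grad` of the summed batch output would give the same values. That shortcut breaks as soon as samples interact (for example, through batch statistics), whereas the vmap-of-grad form does not rely on it.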