I need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do:
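```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.grad(u))  # maps over axis 0 only, so each sample has shape (1,)

plt.plot(x, u(x))
plt.plot(x, ux(x))  # fails: jax.grad needs a scalar output, but u returns shape (1,) here
plt.show()
```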
I have tried several ways of making this work with vmap, but I can't get the code to run without removing the batch dimension from the input x. I have seen some workarounds using the Jacobian, but this doesn't seem natural, since the given function is a scalar function of a single variable.
In the end, u will be a neural network (implemented in Flax) that I need to differentiate with respect to its input (not the network parameters), so I cannot remove the batch dimension.
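For the network case, one possible approach is to keep a single vmap over the batch axis and reduce each per-sample output to a scalar (here by indexing) before taking jax.grad. The parameters are closed over, so the gradient is taken with respect to the input only. This is a minimal sketch: mlp and params are hypothetical hand-written stand-ins for a Flax model's apply call, not Flax API, and the network is assumed to map a feature vector of shape (1,) to an output of shape (1,).

```python
import jax
import jax.numpy as jnp

def mlp(params, xi):
    # Hypothetical stand-in for a Flax model's apply: maps shape (1,) -> shape (1,)
    w1, b1, w2, b2 = params
    h = jnp.tanh(xi @ w1 + b1)  # (1,) @ (1, 16) -> (16,)
    return h @ w2 + b2          # (16,) @ (16, 1) -> (1,)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    jax.random.normal(key1, (1, 16)), jnp.zeros(16),
    jax.random.normal(key2, (16, 1)), jnp.zeros(1),
)

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)

# Scalar-valued function of one sample's input; params are captured by the
# closure, so jax.grad differentiates w.r.t. xi rather than the weights.
u_scalar = lambda xi: mlp(params, xi)[0]

# grad returns shape (1,) per sample; vmap over the batch axis gives (20, 1).
ux = jax.vmap(jax.grad(u_scalar))
du = ux(x)
```

Indexing with [0] assumes a single output unit; for a different output layout, the reduction to a scalar would need to change accordingly.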
To ensure the kernel (u) returns a scalar value, so that jax.grad makes sense, the batched dimension also needs to be mapped over. Which composition of maps to use depends on what's happening in the batched dimension.
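In the question's example, u acts elementwise, so one composition that works is to nest two vmaps: the outer one over the 20 samples and the inner one over the trailing size-1 axis. That way jax.grad(u) only ever sees a scalar input and returns a scalar output. A minimal sketch, assuming the same x and u as in the question:

```python
import jax
import jax.numpy as jnp

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
u = lambda x: jnp.sin(jnp.pi * x)

# Outer vmap maps over axis 0 (samples), inner vmap over the size-1 axis,
# so jax.grad(u) is applied to scalars only.
ux = jax.vmap(jax.vmap(jax.grad(u)))

du = ux(x)  # shape (20, 1); analytically pi * cos(pi * x)
```

This relies on u being elementwise. If u mixes the entries of each sample (as a network with several input features would), mapping over the inner axis is no longer correct, and the per-sample output has to be reduced to a scalar instead.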