I am trying to compute the second derivative of the output with respect to the input of a neural network built using Flax. The network is structured as follows:
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import optim

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        # Hidden layers: Dense followed by tanh
        for feat in self.features[:-1]:
            x = nn.tanh(nn.Dense(feat)(x))
        # Output layer: Dense with no activation
        x = nn.Dense(self.features[-1])(x)
        return x

model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3))  # Dummy input to initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)
X = jnp.ones((32, 3))
output = model.apply(params, X)
I can get the first derivative by using vmap over grad:
from jax import grad, jit, vmap

@jit
def u_function(params, X):
    u = model.apply(params, X)
    return jnp.squeeze(u)

grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
However, when I try to do this again to obtain the second derivative:
u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
I get the following error:
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
186 kernel = self.param('kernel',
187 self.kernel_init,
--> 188 (jnp.shape(inputs)[-1], self.features),
189 self.param_dtype)
190 if self.use_bias:
IndexError: tuple index out of range
I tried using the hvp definition from the autodiff cookbook, but with params being an input to the function, I wasn't sure how to proceed.
Any help on this would be greatly appreciated.
The issue is that your u_function maps a length-3 vector to a scalar. The first derivative of this is a length-3 vector, but the second derivative is a 3x3 Hessian matrix, which you cannot compute via jax.grad, since it is only designed for scalar-output functions. Fortunately, JAX provides the jax.hessian transform to compute these general second derivatives: