Calculating the Hessian-vector product of a Flax NN output with respect to the inputs


I am trying to get the second derivative of the output w.r.t. the input of a neural network built using Flax. The network is structured as follows:

from typing import Sequence

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.tanh(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3))  # dummy input used to initialize the parameters
params = model.init(jax.random.PRNGKey(0), batch)
X =  jnp.ones((32, 3))
output = model.apply(params, X)
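
For reference, the forward pass maps the (32, 3) batch to a (32, 1) output, one value per example:

print(output.shape)
# (32, 1)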

I can get the first derivative by using vmap over grad:

@jit
def u_function(params, X):
  # Squeeze so that a single (3,)-shaped input yields a scalar output
  u = model.apply(params, X)
  return jnp.squeeze(u)

grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=0)

u_X = grad_fn(params, X)
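
This works as expected, giving one length-3 gradient per batch element:

print(u_X.shape)
# (32, 3)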

However, when I try to do the same thing again to obtain the second derivative:

u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=0)
u_XX = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=0)(params, X)

I get the following error:

/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    186     kernel = self.param('kernel',
    187                         self.kernel_init,
--> 188                         (jnp.shape(inputs)[-1], self.features),
    189                         self.param_dtype)
    190     if self.use_bias:

IndexError: tuple index out of range

I tried using the hvp definition from the JAX autodiff cookbook, but with params being an extra input to the function I wasn't sure how to proceed.
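
For reference, one possible adaptation (a minimal sketch, with an illustrative direction vector v) closes over params so that the cookbook's forward-over-reverse hvp differentiates with respect to x alone; this runs, but only yields a directional second derivative rather than the full set of second derivatives I need:

def hvp(params, x, v):
  # Forward-over-reverse HVP: params is closed over, so both grad and
  # jvp differentiate with respect to the input x alone.
  grad_x = jax.grad(lambda xi: u_function(params, xi))
  return jax.jvp(grad_x, (x,), (v,))[1]

v = jnp.ones(3)  # illustrative direction vector
u_XX_v = vmap(hvp, in_axes=(None, 0, None))(params, X, v)
print(u_XX_v.shape)
# (32, 3)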

Any help on this would be greatly appreciated.

1 Answer

Answered by jakevdp:

The issue is that your u_function maps a length-3 vector to a scalar. The first derivative of this is a length-3 vector, but the second derivative is a 3×3 Hessian matrix, which you cannot compute via jax.grad, because jax.grad is designed only for scalar-output functions. Fortunately, JAX provides the jax.hessian transform to compute these general second derivatives:

from jax import hessian

u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=0)(params, X)
print(u_XX.shape)
# (32, 3, 3)
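
If you then only need the sum of the pure second derivatives per example (the Laplacian, common in PDE-style losses), it can be read directly off the batched Hessian:

# Trace of each per-example 3x3 Hessian
laplacian = jnp.trace(u_XX, axis1=1, axis2=2)  # shape (32,)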