Computing the dot product of the gradient with itself for a neural network model in JAX


I have the following piece of JAX code, where model is my neural network model:

(loss, (inner_state, logits)), grad = jax.value_and_grad(
    lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model)

so grad is actually of type flax.nn.Model (I have checked that).

The function forward_and_loss just computes the loss and the updated state of the model, as follows:

def forward_and_loss(model: flax.nn.Model, true_gradient: bool = False):
  """Returns the model's loss, updated state and predictions.

  Args:
    model: The model that we are training.
    true_gradient: If true, the same mixing parameter will be used for the
      forward and backward pass for the Shake Shake and Shake Drop
      regularization (see papers for more details).
  """
  with flax.nn.stateful(state) as new_state:
    with flax.nn.stochastic(prng_key):
      try:
        logits = model(
            batch['image'], train=True, true_gradient=true_gradient)
      except TypeError:
        logits = model(batch['image'], train=True)
  loss = cross_entropy_loss(logits, batch['label'])
  # We apply weight decay to all parameters, including bias and batch norm
  # parameters (unless FLAGS.no_weight_decay_on_bn is set).
  weight_penalty_params = jax.tree_leaves(model.params)
  if FLAGS.no_weight_decay_on_bn:
    weight_l2 = sum(
        [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
  else:
    weight_l2 = sum([jnp.sum(x**2) for x in weight_penalty_params])
  weight_penalty = l2_reg * 0.5 * weight_l2
  loss = loss + weight_penalty
  return loss, (new_state, logits)

Now I would like to compute the dot product of grad with itself, i.e. grad * grad (furthermore, I want one of the two grads to have its gradient stopped, i.e. grad * jax.lax.stop_gradient(grad), but let's ignore that for the moment). Let's denote it by grad_grad.

Then I need to take the gradient of grad_grad, getting something like hessian * grad; let's denote it by hessian_grad. Finally, I want to compute grad + alpha * hessian_grad. Can anyone show me how to write JAX code to do this? I am not very familiar with JAX.

Note that the type of hessian_grad must be the same as that of grad, i.e. flax.nn.Model. We can't simply leave hessian_grad as a function.


1 Answer

Answer from jakevdp:

I'm not entirely clear on what you're trying to do. But you said that grads is a flax.nn.Model, which makes me think that the operation you have in mind is element-wise multiplication of each array in the model by itself. In that case, you can do it with jax.tree_util.tree_map:

result = jax.tree_util.tree_map(lambda x: x * x, grads)
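
If you also want the stop_gradient version mentioned in the question, the same pattern should work (a sketch, not something I've run against your model):

result = jax.tree_util.tree_map(lambda x: x * jax.lax.stop_gradient(x), grads)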

If that's not what you have in mind, then perhaps you could expand your question with more information on what exactly grads contains, and what it is that you'd like to compute.

Edit: based on your updated question, I suspect what you have in mind is

result = jax.tree_util.tree_map(lambda g: g @ g.T, grads)
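
If the final goal is the full grad + alpha * (hessian @ grad) update described in the question, here is a minimal sketch along the same lines. I haven't run it against your setup; it assumes forward_and_loss, model and the other globals from the question (state, prng_key, batch, etc.) are in scope, and alpha is a placeholder step size:

import jax
import jax.numpy as jnp

def loss_fn(m):
  # Only the scalar loss is needed here; forward_and_loss is the question's function.
  loss, _ = forward_and_loss(m, true_gradient=True)
  return loss

# First-order gradient: a pytree (flax.nn.Model) with the same structure as model.
grad = jax.grad(loss_fn)(model)

def grad_dot_grad(m):
  # <grad(m), stop_gradient(grad(m))>: a scalar whose gradient w.r.t. the
  # parameters is hessian @ grad (the stop_gradient keeps one factor constant).
  g = jax.grad(loss_fn)(m)
  dots = jax.tree_util.tree_map(
      lambda x: jnp.sum(x * jax.lax.stop_gradient(x)), g)
  return sum(jax.tree_util.tree_leaves(dots))

# Same pytree structure as grad, so the two can be combined leaf by leaf.
hessian_grad = jax.grad(grad_dot_grad)(model)

alpha = 0.1  # placeholder value
new_grad = jax.tree_util.tree_map(lambda g, hg: g + alpha * hg, grad, hessian_grad)

An alternative for the second-order part is the usual forward-over-reverse Hessian-vector product, jax.jvp(jax.grad(loss_fn), (model,), (grad,))[1], which also treats grad as a constant tangent vector and so matches the stop_gradient behaviour.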