I have the following piece of JAX code involving my neural network model, model:
(loss, (inner_state, logits)), grad = jax.value_and_grad(
    lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model)
so grad is actually of type flax.nn.Model (I have checked that).
The function forward_and_loss just computes the loss and some auxiliary outputs of the model, as follows:
def forward_and_loss(model: flax.nn.Model, true_gradient: bool = False):
  """Returns the model's loss, updated state and predictions.

  Args:
    model: The model that we are training.
    true_gradient: If true, the same mixing parameter will be used for the
      forward and backward pass for the Shake Shake and Shake Drop
      regularization (see papers for more details).
  """
  with flax.nn.stateful(state) as new_state:
    with flax.nn.stochastic(prng_key):
      try:
        logits = model(
            batch['image'], train=True, true_gradient=true_gradient)
      except TypeError:
        logits = model(batch['image'], train=True)
  loss = cross_entropy_loss(logits, batch['label'])
  # We apply weight decay to all parameters, including bias and batch norm
  # parameters.
  weight_penalty_params = jax.tree_leaves(model.params)
  if FLAGS.no_weight_decay_on_bn:
    weight_l2 = sum(
        [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
  else:
    weight_l2 = sum([jnp.sum(x ** 2) for x in weight_penalty_params])
  weight_penalty = l2_reg * 0.5 * weight_l2
  loss = loss + weight_penalty
  return loss, (new_state, logits)
Now I would like to compute the dot product of grad with itself, i.e. grad * grad (furthermore, I want one of the two factors to have its gradient stopped, i.e. grad * jax.lax.stop_gradient(grad), but let's ignore that for the moment). Let's denote the result by grad_grad.
Then I need to take the gradient of grad_grad, giving something like hessian * grad; let's denote it by hessian_grad. Finally, I want to compute grad + alpha * hessian_grad. Can anyone show me how to write this in JAX? I am not very familiar with JAX.
Note that hessian_grad must have the same type as grad, namely flax.nn.Model; we can't simply leave hessian_grad as a function.
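To make it concrete, here is a rough sketch of what I have in mind, reusing the context above (state, prng_key, batch, FLAGS, l2_reg and my scalar alpha are all assumed to be in scope); I don't know whether this is how one would actually write it in JAX:

import jax
import jax.numpy as jnp

def loss_fn(m):
  # Scalar loss only; drop the (new_state, logits) aux outputs.
  return forward_and_loss(m, true_gradient=True)[0]

def grad_dot_grad(m):
  # Gradient of the loss w.r.t. the model; same pytree type as the model.
  g = jax.grad(loss_fn)(m)
  # Stop the gradient on one copy so it is treated as a constant below.
  g_const = jax.tree_util.tree_map(jax.lax.stop_gradient, g)
  # Dot product of the flattened gradient with its stopped copy.
  return sum(jnp.vdot(a, b)
             for a, b in zip(jax.tree_util.tree_leaves(g),
                             jax.tree_util.tree_leaves(g_const)))

# Differentiating grad_dot_grad w.r.t. the model should give hessian * grad,
# with the same pytree structure (flax.nn.Model) as grad itself.
hessian_grad = jax.grad(grad_dot_grad)(model)

# Combine leaf by leaf: grad + alpha * hessian_grad.
update = jax.tree_util.tree_map(lambda g, hg: g + alpha * hg, grad, hessian_grad)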
I'm not entirely clear on what you're trying to do. But you said that grads is a flax.nn.Model, which makes me think that the operation you have in mind is element-wise multiplication of each array in the model by itself. In that case, you can do it with jax.tree_util.tree_map:
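Something along these lines, for example (a sketch assuming grads is the gradient pytree returned by jax.value_and_grad above; a flax.nn.Model is a pytree of arrays, so tree_map works on it directly):

import jax

# Square each array in the gradient pytree element-wise.
grad_grad = jax.tree_util.tree_map(lambda x: x * x, grads)

# Or, with one factor's gradient stopped, as mentioned in the question:
grad_grad = jax.tree_util.tree_map(
    lambda x: x * jax.lax.stop_gradient(x), grads)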
If that's not what you have in mind, then perhaps you could expand your question with more information on what exactly grads contains, and what it is that you'd like to compute.
Edit: based on your updated question, I suspect what you have in mind is