NaN gradients in Jax softmax


I have encountered some mysterious NaN gradient issues while training a model implemented with Flax. With the help of jax_debug_nans I can tell that the NaNs come from the gradient of a transformer block whose code is provided below, but it is rather difficult to understand what actually goes wrong.
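
For reference, NaN checking was enabled with the standard debugging flag, roughly as follows:

import jax

# raise a FloatingPointError as soon as a NaN shows up in any computation
jax.config.update("jax_debug_nans", True)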

The error stack trace captured through jax_debug_nans is:

Traceback (most recent call last):
  File "/home/_/_/codes/_/_/main.py", line 61, in <module>
    app.run(main)
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/_/_/codes/_/_/main.py", line 53, in main
    train_eval.train(FLAGS.config, FLAGS.workdir)
  File "/home/_/_/codes/_/_/_/train_eval.py", line 143, in train
    (rng, training_state), loss = train_step_fn((rng, training_state), processed_data)
  File "/home/_/_/codes/_/_/_/training/losses.py", line 304, in step_fn
    (loss, (y_pred_mu, y_pred_sigma, mu_context, sigma_context, mu_tgt, sigma_tgt, mc_mean, kl_mean)), grad = grad_fn(step_rng, params, batch) 
  File "/home/_/_/codes/_/_/_/training/losses.py", line 29, in anonymous_loss
    jax.vmap(partial_model_apply, in_axes=0)(data_x, data_y, data_x, data_y, context_mask, target_mask)
  File "/home/_/_/codes/_/_/_/training/losses.py", line 24, in <lambda>
    model.apply(variables, x_context=x_ctx, y_context=y_ctx, x_target=x_tgt, \
  File "/home/_/_/codes/_/_/_/models/model.py", line 446, in __call__
    v_star = self.query_specific_encode(x_context, y_context, context_mask, x_target)
  File "/home/_/_/codes/_/_/_/models/model.py", line 408, in query_specific_encode
    v_star = self.qkv_to_v_star(q, k, v, ctx_mask)
  File "/home/_/_/codes/_/_/_/models/utils/nn.py", line 152, in __call__
    h = _scaled_dot_product_attention(qs, ks, vs, 
  File "/home/_/_/codes/_/_/_/models/utils/nn.py", line 34, in _scaled_dot_product_attention
    ws = softmax(
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/jax/_src/nn/functions.py", line 352, in softmax
    return _softmax_deprecated(x, axis, where, initial)
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/jax/_src/nn/functions.py", line 377, in _softmax_deprecated
    result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 791, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/home/_/_/_/envs/_/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 258, in deferring_binary_op
    return binary_op(*args)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(mul)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

I have done extensive investigation; some of my findings are:

  • This is not due to an overly large learning rate (1e-4): the loss curve and the weight/gradient magnitudes look normal (see the figures at the end), and global gradient-norm clipping is applied (same setup as in the Optax quick start notebook).
  • The issue only comes from specific data in a batch; all the remaining data works fine. I have double-checked all of the data and there is no NaN/Inf in it at all.
  • I have tried enabling jax_softmax_custom_jvp to use the non-deprecated softmax, and config.update("jax_enable_x64", True) to rule out numerical-precision issues, but with no luck (see the snippet after this list).
  • The NaN occurs only in the gradients (not in the loss value or in the parameters used to compute the loss). More specifically, it mainly occurs in the attention blocks; the full gradient dump is provided in this gist.
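
For completeness, those configuration tweaks were applied roughly like this (a minimal sketch; the flag names are the documented JAX config options):

import jax

# use the newer custom-JVP softmax implementation instead of the deprecated path
jax.config.update("jax_softmax_custom_jvp", True)

# run everything in double precision to rule out float32 accuracy issues
jax.config.update("jax_enable_x64", True)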

The attention block is implemented as:

from typing import Optional

import jax
import jax.numpy as np  # jax.numpy is aliased as np throughout this module
from einops import rearrange
from jax.nn import softmax

def _scaled_dot_product_attention(Q: jax.Array, K: jax.Array, V: jax.Array,
                                  Q_mask: jax.Array, K_mask: jax.Array) -> jax.Array:
    d_k = Q.shape[-1]
    # scaled dot-product scores; softmax normalised only over unmasked keys (where=K_mask)
    ws = softmax(
        np.matmul(Q * Q_mask[..., None],
                  rearrange(K * K_mask[..., None], '... seq_length key_dim -> ... key_dim seq_length')) / d_k ** 0.5,
        where=K_mask, initial=0.0)
    return np.matmul(ws, (V * K_mask[..., None]))


class MultiHeadCrossAttentionBlock(MultiHeadSelfAttentionBlock):
    """
    The cross-attention block
    """

    def __call__(self, queries: jax.Array, keys: jax.Array, values: jax.Array,
                 keys_mask: Optional[jax.Array]) -> jax.Array:
        rearrange_arg = (
            "... (num_heads key_dim) -> num_heads ... key_dim"
        )
        qs = rearrange(
            self.projs_q(queries),
            rearrange_arg,
            num_heads=self.heads_num,
            key_dim=self.key_dim,
        )
        ks = rearrange(
            self.projs_k(keys),
            rearrange_arg,
            num_heads=self.heads_num,
            key_dim=self.key_dim,
        )
        vs = rearrange(
            self.projs_v(values),
            rearrange_arg,
            num_heads=self.heads_num,
            key_dim=self.key_dim,
        )
        h = _scaled_dot_product_attention(qs, ks, vs, 
                                          Q_mask = np.ones(shape=(queries.shape[:-1])).astype(np.bool_),
                                          K_mask=keys_mask)  # [num_heads, ..., H]
        # concatenate and projection
        h = self.proj(
            np.squeeze(
                np.concatenate(
                    np.split(h, indices_or_sections=self.heads_num, axis=0), axis=-1
                ),
                axis=0,
            )
        )  # proj: [num_heads, ..., H] -> [..., num_heads * H] -> [..., H]
        return h
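
For anyone who wants to poke at the attention code in isolation, this is roughly how it can be called and differentiated on its own (a toy sketch with made-up shapes and random data, not my actual training setup; it reuses the _scaled_dot_product_attention defined above):

import jax
import jax.numpy as np

# toy dimensions, for illustration only
seq_len, key_dim = 4, 8

def toy_loss(q, k, v, k_mask):
    q_mask = np.ones(q.shape[:-1], dtype=np.bool_)
    out = _scaled_dot_product_attention(q, k, v, q_mask, k_mask)
    return np.sum(out ** 2)

rng = jax.random.PRNGKey(0)
rq, rk, rv = jax.random.split(rng, 3)
q = jax.random.normal(rq, (seq_len, key_dim))
k = jax.random.normal(rk, (seq_len, key_dim))
v = jax.random.normal(rv, (seq_len, key_dim))
k_mask = np.array([True, True, False, False])  # a couple of masked-out keys

grads = jax.grad(toy_loss, argnums=(0, 1, 2))(q, k, v, k_mask)
print([bool(np.isnan(g).any()) for g in grads])  # check which gradients contain NaN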

I am almost out of ideas and have been stuck on this for a while; any hints on this issue would be greatly appreciated!


Additional info:

[figure: loss curve]

[figure: weight and gradient norms during training]
