How to debug exploding gradient (covariance matrix) in Tensorflow 2.0 (TFP)


This question comes from the fact that I have never had to debug my models in TF this deeply before.

I'm running variational inference with a full-rank Gaussian approximation using TensorFlow Probability, and I noticed that the optimization often explodes (the loss curve suddenly blows up).

I suspect numerical issues, since the losses and the optimization process look reasonable up to that point and I don't observe any NaNs.

I use tfp.distributions.MultivariateNormalTriL, with the triangular covariance factor parameterized through tfp.bijectors.FillScaleTriL with the default diagonal shift. The condition number of the covariance matrix is reasonable. The variational inference is performed with the fit_surrogate_posterior function.

I optimize with SGD with momentum, using 10 Monte Carlo samples per iteration.
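
For reference, here is a minimal sketch of this kind of setup (the target distribution, dimensionality, learning rate, and number of steps below are placeholders, not my actual model):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions
    tfb = tfp.bijectors

    # Placeholder target just to make the sketch self-contained.
    target = tfd.MultivariateNormalDiag(loc=[1., -1.], scale_diag=[0.5, 2.])
    dims = 2

    loc = tf.Variable(tf.zeros(dims), name='loc')
    # Unconstrained parameters mapped to a valid lower-triangular scale factor;
    # FillScaleTriL's default diagonal shift keeps the diagonal positive.
    scale_tril = tfp.util.TransformedVariable(
        tf.eye(dims), bijector=tfb.FillScaleTriL(), name='scale_tril')
    surrogate = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)

    losses = tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=target.log_prob,
        surrogate_posterior=surrogate,
        optimizer=tf.optimizers.SGD(learning_rate=0.01, momentum=0.9),
        num_steps=500,
        sample_size=10)  # 10 Monte Carlo samples per iteration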

Internally, the TensorFlow Probability source code computes the minimization objective under a gradient tape:

    with tf.GradientTape(watch_accessed_variables=trainable_variables is None) as tape:
      for v in trainable_variables or []:
        tape.watch(v)
      loss = loss_fn()

In order to solve my issue, I would like to see the gradient flowing through every operation.

My question is: how can I get more insight into which operation makes the gradient explode? How can I get the value of the gradient at every tensor?
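
For illustration, continuing the placeholder sketch above, something along these lines gives me per-variable gradient norms (recomputing the Monte Carlo ELBO loss that, as far as I understand, fit_surrogate_posterior minimizes by default), but not the gradient at every intermediate tensor:

    # Recompute the variational loss under a tape and inspect per-variable gradients.
    with tf.GradientTape() as tape:
        loss = tfp.vi.monte_carlo_variational_loss(
            target_log_prob_fn=target.log_prob,
            surrogate_posterior=surrogate,
            sample_size=10)
    grads = tape.gradient(loss, surrogate.trainable_variables)
    for var, grad in zip(surrogate.trainable_variables, grads):
        tf.print(var.name, tf.norm(grad))  # which variable's gradient blows up?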

And if any of you have faced a similar issue: is there a better way to prevent instabilities in the covariance matrix optimization?

Detailed explanations:

I observed that this explosion is caused by a single parameter (though it is not always the same parameter that explodes). This can be checked simply by comparing the covariance matrix from two iterations before the explosion with the one from the iteration just before the loss explodes: note the last parameter. When I run the same optimization multiple times, it may happen that one of the other "small" parameters (rows 9 to the last) explodes at some point instead.

Thanks, Mateusz

1 Answer

Answered by Yaoshiang:

The professional answer is to run in eager mode and attach a debugger, but that never turns out to be as easy as it sounds, and eager mode and graph mode do not behave exactly the same.
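
For example, assuming a reasonably recent TF 2.x, something like this forces functions to run eagerly so you can attach a Python debugger:

    import tensorflow as tf

    # Force tf.function-decorated code to run eagerly so a Python debugger
    # (pdb, IDE breakpoints) can step through it. On older TF 2.x releases the
    # call is tf.config.experimental_run_functions_eagerly(True) instead.
    tf.config.run_functions_eagerly(True)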

An exploding gradient is usually caused by a weight that's too big, or too close to zero and used in a division or a log. I don't know the math behind your work, but when I hear variance transfer, I wonder if you might be doing a fancier version of taking a distribution and dividing by its variance to try to get unit variance. Well, if your variance is very low, like 1e-5, but there's a single outlier of value 1, you just got a huge weight of 1e5.

I'd personally start by placing tf.debugging.check_numerics(...) all over the place. In a pinch, if you think the values will eventually converge to reasonable values, you can also try gradient clipping.
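
A rough sketch of those ideas (the values here are placeholders, not tuned for your model):

    import tensorflow as tf

    # Instrument every op so the first NaN/Inf raises an error naming the
    # offending op (available in recent TF 2.x releases).
    tf.debugging.enable_check_numerics()

    # Or spot-check individual tensors by hand; this raises if x is non-finite.
    x = tf.math.log(tf.constant([1e-5, 1.0]))
    x = tf.debugging.check_numerics(x, message='log produced a non-finite value')

    # For gradient clipping, Keras optimizers accept clipnorm / clipvalue, e.g.:
    optimizer = tf.optimizers.SGD(learning_rate=0.01, momentum=0.9, clipnorm=1.0)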

Another idea: if you are doing something along the lines of

    distribution_a /= std(distribution_a)
    distribution_a *= std(distribution_b)

You might rewrite this as

    factor = std(distribution_b) / (std(distribution_a) + epsilon)
    factor = min(factor, 5.0)  # cap the rescaling factor
    distribution_a *= factor
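
The epsilon keeps the denominator away from zero even when distribution_a has almost no spread, and the cap bounds how much a single step can rescale distribution_a.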